Squashed 'packages/tools/' content from commit 78317b9c

git-subtree-dir: packages/tools
git-subtree-split: 78317b9c127f18bd040c1d77e3c0840cdc9a5b38
This commit is contained in:
Greyson Lalonde
2025-09-12 21:58:02 -04:00
commit e16606672a
303 changed files with 49010 additions and 0 deletions

0
tests/__init__.py Normal file
View File

View File

@@ -0,0 +1,230 @@
from textwrap import dedent
from unittest.mock import MagicMock, patch
import pytest
from mcp import StdioServerParameters
from crewai_tools import MCPServerAdapter
from crewai_tools.adapters.tool_collection import ToolCollection
@pytest.fixture
def echo_server_script():
return dedent(
'''
from mcp.server.fastmcp import FastMCP
mcp = FastMCP("Echo Server")
@mcp.tool()
def echo_tool(text: str) -> str:
"""Echo the input text"""
return f"Echo: {text}"
@mcp.tool()
def calc_tool(a: int, b: int) -> int:
"""Calculate a + b"""
return a + b
mcp.run()
'''
)
@pytest.fixture
def echo_server_sse_script():
return dedent(
'''
from mcp.server.fastmcp import FastMCP
mcp = FastMCP("Echo Server", host="127.0.0.1", port=8000)
@mcp.tool()
def echo_tool(text: str) -> str:
"""Echo the input text"""
return f"Echo: {text}"
@mcp.tool()
def calc_tool(a: int, b: int) -> int:
"""Calculate a + b"""
return a + b
mcp.run("sse")
'''
)
@pytest.fixture
def echo_sse_server(echo_server_sse_script):
import subprocess
import time
# Start the SSE server process with its own process group
process = subprocess.Popen(
["python", "-c", echo_server_sse_script],
)
# Give the server a moment to start up
time.sleep(1)
try:
yield {"url": "http://127.0.0.1:8000/sse"}
finally:
# Clean up the process when test is done
process.kill()
process.wait()
def test_context_manager_syntax(echo_server_script):
serverparams = StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
)
with MCPServerAdapter(serverparams) as tools:
assert isinstance(tools, ToolCollection)
assert len(tools) == 2
assert tools[0].name == "echo_tool"
assert tools[1].name == "calc_tool"
assert tools[0].run(text="hello") == "Echo: hello"
assert tools[1].run(a=5, b=3) == '8'
def test_context_manager_syntax_sse(echo_sse_server):
sse_serverparams = echo_sse_server
with MCPServerAdapter(sse_serverparams) as tools:
assert len(tools) == 2
assert tools[0].name == "echo_tool"
assert tools[1].name == "calc_tool"
assert tools[0].run(text="hello") == "Echo: hello"
assert tools[1].run(a=5, b=3) == '8'
def test_try_finally_syntax(echo_server_script):
serverparams = StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
)
try:
mcp_server_adapter = MCPServerAdapter(serverparams)
tools = mcp_server_adapter.tools
assert len(tools) == 2
assert tools[0].name == "echo_tool"
assert tools[1].name == "calc_tool"
assert tools[0].run(text="hello") == "Echo: hello"
assert tools[1].run(a=5, b=3) == '8'
finally:
mcp_server_adapter.stop()
def test_try_finally_syntax_sse(echo_sse_server):
sse_serverparams = echo_sse_server
mcp_server_adapter = MCPServerAdapter(sse_serverparams)
try:
tools = mcp_server_adapter.tools
assert len(tools) == 2
assert tools[0].name == "echo_tool"
assert tools[1].name == "calc_tool"
assert tools[0].run(text="hello") == "Echo: hello"
assert tools[1].run(a=5, b=3) == '8'
finally:
mcp_server_adapter.stop()
def test_context_manager_with_filtered_tools(echo_server_script):
serverparams = StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
)
# Only select the echo_tool
with MCPServerAdapter(serverparams, "echo_tool") as tools:
assert isinstance(tools, ToolCollection)
assert len(tools) == 1
assert tools[0].name == "echo_tool"
assert tools[0].run(text="hello") == "Echo: hello"
# Check that calc_tool is not present
with pytest.raises(IndexError):
_ = tools[1]
with pytest.raises(KeyError):
_ = tools["calc_tool"]
def test_context_manager_sse_with_filtered_tools(echo_sse_server):
sse_serverparams = echo_sse_server
# Only select the calc_tool
with MCPServerAdapter(sse_serverparams, "calc_tool") as tools:
assert isinstance(tools, ToolCollection)
assert len(tools) == 1
assert tools[0].name == "calc_tool"
assert tools[0].run(a=10, b=5) == '15'
# Check that echo_tool is not present
with pytest.raises(IndexError):
_ = tools[1]
with pytest.raises(KeyError):
_ = tools["echo_tool"]
def test_try_finally_with_filtered_tools(echo_server_script):
serverparams = StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
)
try:
# Select both tools but in reverse order
mcp_server_adapter = MCPServerAdapter(serverparams, "calc_tool", "echo_tool")
tools = mcp_server_adapter.tools
assert len(tools) == 2
# The order of tools is based on filter_by_names which preserves
# the original order from the collection
assert tools[0].name == "calc_tool"
assert tools[1].name == "echo_tool"
finally:
mcp_server_adapter.stop()
def test_filter_with_nonexistent_tool(echo_server_script):
serverparams = StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
)
# Include a tool that doesn't exist
with MCPServerAdapter(serverparams, "echo_tool", "nonexistent_tool") as tools:
# Only echo_tool should be in the result
assert len(tools) == 1
assert tools[0].name == "echo_tool"
def test_filter_with_only_nonexistent_tools(echo_server_script):
serverparams = StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
)
# All requested tools don't exist
with MCPServerAdapter(serverparams, "nonexistent1", "nonexistent2") as tools:
# Should return an empty tool collection
assert isinstance(tools, ToolCollection)
assert len(tools) == 0
def test_connect_timeout_parameter(echo_server_script):
serverparams = StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
)
with MCPServerAdapter(serverparams, connect_timeout=60) as tools:
assert isinstance(tools, ToolCollection)
assert len(tools) == 2
assert tools[0].name == "echo_tool"
assert tools[1].name == "calc_tool"
assert tools[0].run(text="hello") == "Echo: hello"
def test_connect_timeout_with_filtered_tools(echo_server_script):
serverparams = StdioServerParameters(
command="uv", args=["run", "python", "-c", echo_server_script]
)
with MCPServerAdapter(serverparams, "echo_tool", connect_timeout=45) as tools:
assert isinstance(tools, ToolCollection)
assert len(tools) == 1
assert tools[0].name == "echo_tool"
assert tools[0].run(text="timeout test") == "Echo: timeout test"
@patch('crewai_tools.adapters.mcp_adapter.MCPAdapt')
def test_connect_timeout_passed_to_mcpadapt(mock_mcpadapt):
mock_adapter_instance = MagicMock()
mock_mcpadapt.return_value = mock_adapter_instance
serverparams = StdioServerParameters(
command="uv", args=["run", "echo", "test"]
)
MCPServerAdapter(serverparams)
mock_mcpadapt.assert_called_once()
assert mock_mcpadapt.call_args[0][2] == 30
mock_mcpadapt.reset_mock()
MCPServerAdapter(serverparams, connect_timeout=5)
mock_mcpadapt.assert_called_once()
assert mock_mcpadapt.call_args[0][2] == 5

104
tests/base_tool_test.py Normal file
View File

@@ -0,0 +1,104 @@
from typing import Callable
from crewai.tools import BaseTool, tool
from crewai.tools.base_tool import to_langchain
def test_creating_a_tool_using_annotation():
@tool("Name of my tool")
def my_tool(question: str) -> str:
"""Clear description for what this tool is useful for, you agent will need this information to use it."""
return question
# Assert all the right attributes were defined
assert my_tool.name == "Name of my tool"
assert (
my_tool.description
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
)
assert my_tool.args_schema.model_json_schema()["properties"] == {
"question": {"title": "Question", "type": "string"}
}
assert (
my_tool.func("What is the meaning of life?") == "What is the meaning of life?"
)
# Assert the langchain tool conversion worked as expected
converted_tool = to_langchain([my_tool])[0]
assert converted_tool.name == "Name of my tool"
assert (
converted_tool.description
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
)
assert converted_tool.args_schema.model_json_schema()["properties"] == {
"question": {"title": "Question", "type": "string"}
}
assert (
converted_tool.func("What is the meaning of life?")
== "What is the meaning of life?"
)
def test_creating_a_tool_using_baseclass():
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
def _run(self, question: str) -> str:
return question
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.name == "Name of my tool"
assert (
my_tool.description
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
)
assert my_tool.args_schema.model_json_schema()["properties"] == {
"question": {"title": "Question", "type": "string"}
}
assert (
my_tool._run("What is the meaning of life?") == "What is the meaning of life?"
)
# Assert the langchain tool conversion worked as expected
converted_tool = to_langchain([my_tool])[0]
assert converted_tool.name == "Name of my tool"
assert (
converted_tool.description
== "Tool Name: Name of my tool\nTool Arguments: {'question': {'description': None, 'type': 'str'}}\nTool Description: Clear description for what this tool is useful for, you agent will need this information to use it."
)
assert converted_tool.args_schema.model_json_schema()["properties"] == {
"question": {"title": "Question", "type": "string"}
}
assert (
converted_tool.invoke({"question": "What is the meaning of life?"})
== "What is the meaning of life?"
)
def test_setting_cache_function():
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
cache_function: Callable = lambda: False
def _run(self, question: str) -> str:
return question
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.cache_function() == False
def test_default_cache_function_is_true():
class MyCustomTool(BaseTool):
name: str = "Name of my tool"
description: str = "Clear description for what this tool is useful for, you agent will need this information to use it."
def _run(self, question: str) -> str:
return question
my_tool = MyCustomTool()
# Assert all the right attributes were defined
assert my_tool.cache_function() == True

21
tests/conftest.py Normal file
View File

@@ -0,0 +1,21 @@
from typing import Callable
import pytest
class Helpers:
@staticmethod
def get_embedding_function() -> Callable:
def _func(input):
assert input == ["What are the requirements for the task?"]
with open("tests/data/embedding.txt", "r") as file:
content = file.read()
numbers = content.split(",")
return [[float(number) for number in numbers]]
return _func
@pytest.fixture
def helpers():
return Helpers

View File

@@ -0,0 +1,165 @@
import os
from unittest.mock import mock_open, patch
from crewai_tools import FileReadTool
def test_file_read_tool_constructor():
"""Test FileReadTool initialization with file_path."""
# Create a temporary test file
test_file = "/tmp/test_file.txt"
test_content = "Hello, World!"
with open(test_file, "w") as f:
f.write(test_content)
# Test initialization with file_path
tool = FileReadTool(file_path=test_file)
assert tool.file_path == test_file
assert "test_file.txt" in tool.description
# Clean up
os.remove(test_file)
def test_file_read_tool_run():
"""Test FileReadTool _run method with file_path at runtime."""
test_file = "/tmp/test_file.txt"
test_content = "Hello, World!"
# Use mock_open to mock file operations
with patch("builtins.open", mock_open(read_data=test_content)):
# Test reading file with runtime file_path
tool = FileReadTool()
result = tool._run(file_path=test_file)
assert result == test_content
def test_file_read_tool_error_handling():
"""Test FileReadTool error handling."""
# Test missing file path
tool = FileReadTool()
result = tool._run()
assert "Error: No file path provided" in result
# Test non-existent file
result = tool._run(file_path="/nonexistent/file.txt")
assert "Error: File not found at path:" in result
# Test permission error
with patch("builtins.open", side_effect=PermissionError()):
result = tool._run(file_path="/tmp/no_permission.txt")
assert "Error: Permission denied" in result
def test_file_read_tool_constructor_and_run():
"""Test FileReadTool using both constructor and runtime file paths."""
test_file1 = "/tmp/test1.txt"
test_file2 = "/tmp/test2.txt"
content1 = "File 1 content"
content2 = "File 2 content"
# First test with content1
with patch("builtins.open", mock_open(read_data=content1)):
tool = FileReadTool(file_path=test_file1)
result = tool._run()
assert result == content1
# Then test with content2 (should override constructor file_path)
with patch("builtins.open", mock_open(read_data=content2)):
result = tool._run(file_path=test_file2)
assert result == content2
def test_file_read_tool_chunk_reading():
"""Test FileReadTool reading specific chunks of a file."""
test_file = "/tmp/multiline_test.txt"
lines = [
"Line 1\n",
"Line 2\n",
"Line 3\n",
"Line 4\n",
"Line 5\n",
"Line 6\n",
"Line 7\n",
"Line 8\n",
"Line 9\n",
"Line 10\n",
]
file_content = "".join(lines)
with patch("builtins.open", mock_open(read_data=file_content)):
tool = FileReadTool()
# Test reading a specific chunk (lines 3-5)
result = tool._run(file_path=test_file, start_line=3, line_count=3)
expected = "".join(lines[2:5]) # Lines are 0-indexed in the array
assert result == expected
# Test reading from a specific line to the end
result = tool._run(file_path=test_file, start_line=8)
expected = "".join(lines[7:])
assert result == expected
# Test with default values (should read entire file)
result = tool._run(file_path=test_file)
expected = "".join(lines)
assert result == expected
# Test when start_line is 1 but line_count is specified
result = tool._run(file_path=test_file, start_line=1, line_count=5)
expected = "".join(lines[0:5])
assert result == expected
def test_file_read_tool_chunk_error_handling():
"""Test error handling for chunk reading."""
test_file = "/tmp/short_test.txt"
lines = ["Line 1\n", "Line 2\n", "Line 3\n"]
file_content = "".join(lines)
with patch("builtins.open", mock_open(read_data=file_content)):
tool = FileReadTool()
# Test start_line exceeding file length
result = tool._run(file_path=test_file, start_line=10)
assert "Error: Start line 10 exceeds the number of lines in the file" in result
# Test reading partial chunk when line_count exceeds available lines
result = tool._run(file_path=test_file, start_line=2, line_count=10)
expected = "".join(lines[1:]) # Should return from line 2 to end
assert result == expected
def test_file_read_tool_zero_or_negative_start_line():
"""Test that start_line values of 0 or negative read from the start of the file."""
test_file = "/tmp/negative_test.txt"
lines = ["Line 1\n", "Line 2\n", "Line 3\n", "Line 4\n", "Line 5\n"]
file_content = "".join(lines)
with patch("builtins.open", mock_open(read_data=file_content)):
tool = FileReadTool()
# Test with start_line = None
result = tool._run(file_path=test_file, start_line=None)
expected = "".join(lines) # Should read the entire file
assert result == expected
# Test with start_line = 0
result = tool._run(file_path=test_file, start_line=0)
expected = "".join(lines) # Should read the entire file
assert result == expected
# Test with start_line = 0 and limited line count
result = tool._run(file_path=test_file, start_line=0, line_count=3)
expected = "".join(lines[0:3]) # Should read first 3 lines
assert result == expected
# Test with negative start_line
result = tool._run(file_path=test_file, start_line=-5)
expected = "".join(lines) # Should read the entire file
assert result == expected
# Test with negative start_line and limited line count
result = tool._run(file_path=test_file, start_line=-10, line_count=2)
expected = "".join(lines[0:2]) # Should read first 2 lines
assert result == expected

View File

View File

@@ -0,0 +1,21 @@
import pytest
def pytest_configure(config):
"""Register custom markers."""
config.addinivalue_line("markers", "integration: mark test as an integration test")
config.addinivalue_line("markers", "asyncio: mark test as an async test")
# Set the asyncio loop scope through ini configuration
config.inicfg["asyncio_mode"] = "auto"
@pytest.fixture(scope="function")
def event_loop():
"""Create an instance of the default event loop for each test case."""
import asyncio
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
yield loop
loop.close()

0
tests/rag/__init__.py Normal file
View File

View File

@@ -0,0 +1,130 @@
import os
import tempfile
import pytest
from unittest.mock import patch, Mock
from crewai_tools.rag.loaders.csv_loader import CSVLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent
@pytest.fixture
def temp_csv_file():
created_files = []
def _create(content: str):
f = tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False)
f.write(content)
f.close()
created_files.append(f.name)
return f.name
yield _create
for path in created_files:
os.unlink(path)
class TestCSVLoader:
def test_load_csv_from_file(self, temp_csv_file):
path = temp_csv_file("name,age,city\nJohn,25,New York\nJane,30,Chicago")
loader = CSVLoader()
result = loader.load(SourceContent(path))
assert isinstance(result, LoaderResult)
assert "Headers: name | age | city" in result.content
assert "Row 1: name: John | age: 25 | city: New York" in result.content
assert "Row 2: name: Jane | age: 30 | city: Chicago" in result.content
assert result.metadata == {
"format": "csv",
"columns": ["name", "age", "city"],
"rows": 2,
}
assert result.source == path
assert result.doc_id
def test_load_csv_with_empty_values(self, temp_csv_file):
path = temp_csv_file("name,age,city\nJohn,,New York\n,30,")
result = CSVLoader().load(SourceContent(path))
assert "Row 1: name: John | city: New York" in result.content
assert "Row 2: age: 30" in result.content
assert result.metadata["rows"] == 2
def test_load_csv_malformed(self, temp_csv_file):
path = temp_csv_file("invalid,csv\nunclosed quote \"missing")
result = CSVLoader().load(SourceContent(path))
assert "Headers: invalid | csv" in result.content
assert 'Row 1: invalid: unclosed quote "missing' in result.content
assert result.metadata["columns"] == ["invalid", "csv"]
def test_load_csv_empty_file(self, temp_csv_file):
path = temp_csv_file("")
result = CSVLoader().load(SourceContent(path))
assert result.content == ""
assert result.metadata["rows"] == 0
def test_load_csv_text_input(self):
raw_csv = "col1,col2\nvalue1,value2\nvalue3,value4"
result = CSVLoader().load(SourceContent(raw_csv))
assert "Headers: col1 | col2" in result.content
assert "Row 1: col1: value1 | col2: value2" in result.content
assert "Row 2: col1: value3 | col2: value4" in result.content
assert result.metadata["columns"] == ["col1", "col2"]
assert result.metadata["rows"] == 2
def test_doc_id_is_deterministic(self, temp_csv_file):
path = temp_csv_file("name,value\ntest,123")
loader = CSVLoader()
result1 = loader.load(SourceContent(path))
result2 = loader.load(SourceContent(path))
assert result1.doc_id == result2.doc_id
@patch("requests.get")
def test_load_csv_from_url(self, mock_get):
mock_get.return_value = Mock(
text="name,value\ntest,123",
raise_for_status=Mock(return_value=None)
)
result = CSVLoader().load(SourceContent("https://example.com/data.csv"))
assert "Headers: name | value" in result.content
assert "Row 1: name: test | value: 123" in result.content
headers = mock_get.call_args[1]["headers"]
assert "text/csv" in headers["Accept"]
assert "crewai-tools CSVLoader" in headers["User-Agent"]
@patch("requests.get")
def test_load_csv_with_custom_headers(self, mock_get):
mock_get.return_value = Mock(
text="data,value\ntest,456",
raise_for_status=Mock(return_value=None)
)
headers = {"Authorization": "Bearer token", "Custom-Header": "value"}
result = CSVLoader().load(SourceContent("https://example.com/data.csv"), headers=headers)
assert "Headers: data | value" in result.content
assert mock_get.call_args[1]["headers"] == headers
@patch("requests.get")
def test_csv_loader_handles_network_errors(self, mock_get):
mock_get.side_effect = Exception("Network error")
loader = CSVLoader()
with pytest.raises(ValueError, match="Error fetching CSV from URL"):
loader.load(SourceContent("https://example.com/data.csv"))
@patch("requests.get")
def test_csv_loader_handles_http_error(self, mock_get):
mock_get.return_value = Mock()
mock_get.return_value.raise_for_status.side_effect = Exception("404 Not Found")
loader = CSVLoader()
with pytest.raises(ValueError, match="Error fetching CSV from URL"):
loader.load(SourceContent("https://example.com/notfound.csv"))

View File

@@ -0,0 +1,149 @@
import os
import tempfile
import pytest
from crewai_tools.rag.loaders.directory_loader import DirectoryLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent
@pytest.fixture
def temp_directory():
with tempfile.TemporaryDirectory() as temp_dir:
yield temp_dir
class TestDirectoryLoader:
def _create_file(self, directory, filename, content="test content"):
path = os.path.join(directory, filename)
with open(path, "w") as f:
f.write(content)
return path
def test_load_non_recursive(self, temp_directory):
self._create_file(temp_directory, "file1.txt")
self._create_file(temp_directory, "file2.txt")
subdir = os.path.join(temp_directory, "subdir")
os.makedirs(subdir)
self._create_file(subdir, "file3.txt")
loader = DirectoryLoader()
result = loader.load(SourceContent(temp_directory), recursive=False)
assert isinstance(result, LoaderResult)
assert "file1.txt" in result.content
assert "file2.txt" in result.content
assert "file3.txt" not in result.content
assert result.metadata["total_files"] == 2
def test_load_recursive(self, temp_directory):
self._create_file(temp_directory, "file1.txt")
nested = os.path.join(temp_directory, "subdir", "nested")
os.makedirs(nested)
self._create_file(os.path.join(temp_directory, "subdir"), "file2.txt")
self._create_file(nested, "file3.txt")
loader = DirectoryLoader()
result = loader.load(SourceContent(temp_directory), recursive=True)
assert all(f"file{i}.txt" in result.content for i in range(1, 4))
def test_include_and_exclude_extensions(self, temp_directory):
self._create_file(temp_directory, "a.txt")
self._create_file(temp_directory, "b.py")
self._create_file(temp_directory, "c.md")
loader = DirectoryLoader()
result = loader.load(SourceContent(temp_directory), include_extensions=[".txt", ".py"])
assert "a.txt" in result.content
assert "b.py" in result.content
assert "c.md" not in result.content
result2 = loader.load(SourceContent(temp_directory), exclude_extensions=[".py", ".md"])
assert "a.txt" in result2.content
assert "b.py" not in result2.content
assert "c.md" not in result2.content
def test_max_files_limit(self, temp_directory):
for i in range(5):
self._create_file(temp_directory, f"file{i}.txt")
loader = DirectoryLoader()
result = loader.load(SourceContent(temp_directory), max_files=3)
assert result.metadata["total_files"] == 3
assert all(f"file{i}.txt" in result.content for i in range(3))
def test_hidden_files_and_dirs_excluded(self, temp_directory):
self._create_file(temp_directory, "visible.txt", "visible")
self._create_file(temp_directory, ".hidden.txt", "hidden")
hidden_dir = os.path.join(temp_directory, ".hidden")
os.makedirs(hidden_dir)
self._create_file(hidden_dir, "inside_hidden.txt")
loader = DirectoryLoader()
result = loader.load(SourceContent(temp_directory), recursive=True)
assert "visible.txt" in result.content
assert ".hidden.txt" not in result.content
assert "inside_hidden.txt" not in result.content
def test_directory_does_not_exist(self):
loader = DirectoryLoader()
with pytest.raises(FileNotFoundError, match="Directory does not exist"):
loader.load(SourceContent("/path/does/not/exist"))
def test_path_is_not_a_directory(self):
with tempfile.NamedTemporaryFile() as f:
loader = DirectoryLoader()
with pytest.raises(ValueError, match="Path is not a directory"):
loader.load(SourceContent(f.name))
def test_url_not_supported(self):
loader = DirectoryLoader()
with pytest.raises(ValueError, match="URL directory loading is not supported"):
loader.load(SourceContent("https://example.com"))
def test_processing_error_handling(self, temp_directory):
self._create_file(temp_directory, "valid.txt")
error_file = self._create_file(temp_directory, "error.txt")
loader = DirectoryLoader()
original_method = loader._process_single_file
def mock(file_path):
if "error" in file_path:
raise ValueError("Mock error")
return original_method(file_path)
loader._process_single_file = mock
result = loader.load(SourceContent(temp_directory))
assert "valid.txt" in result.content
assert "error.txt (ERROR)" in result.content
assert result.metadata["errors"] == 1
assert len(result.metadata["error_details"]) == 1
def test_metadata_structure(self, temp_directory):
self._create_file(temp_directory, "test.txt", "Sample")
loader = DirectoryLoader()
result = loader.load(SourceContent(temp_directory))
metadata = result.metadata
expected_keys = {
"format", "directory_path", "total_files", "processed_files",
"errors", "file_details", "error_details"
}
assert expected_keys.issubset(metadata)
assert all(k in metadata["file_details"][0] for k in ("path", "metadata", "source"))
def test_empty_directory(self, temp_directory):
loader = DirectoryLoader()
result = loader.load(SourceContent(temp_directory))
assert result.content == ""
assert result.metadata["total_files"] == 0
assert result.metadata["processed_files"] == 0

View File

@@ -0,0 +1,135 @@
import tempfile
import pytest
from unittest.mock import patch, Mock
from crewai_tools.rag.loaders.docx_loader import DOCXLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent
class TestDOCXLoader:
@patch('docx.Document')
def test_load_docx_from_file(self, mock_docx_class):
mock_doc = Mock()
mock_doc.paragraphs = [
Mock(text="First paragraph"),
Mock(text="Second paragraph"),
Mock(text=" ") # Blank paragraph
]
mock_doc.tables = []
mock_docx_class.return_value = mock_doc
with tempfile.NamedTemporaryFile(suffix='.docx') as f:
loader = DOCXLoader()
result = loader.load(SourceContent(f.name))
assert isinstance(result, LoaderResult)
assert result.content == "First paragraph\nSecond paragraph"
assert result.metadata == {"format": "docx", "paragraphs": 3, "tables": 0}
assert result.source == f.name
@patch('docx.Document')
def test_load_docx_with_tables(self, mock_docx_class):
mock_doc = Mock()
mock_doc.paragraphs = [Mock(text="Document with table")]
mock_doc.tables = [Mock(), Mock()]
mock_docx_class.return_value = mock_doc
with tempfile.NamedTemporaryFile(suffix='.docx') as f:
loader = DOCXLoader()
result = loader.load(SourceContent(f.name))
assert result.metadata["tables"] == 2
@patch('requests.get')
@patch('docx.Document')
@patch('tempfile.NamedTemporaryFile')
@patch('os.unlink')
def test_load_docx_from_url(self, mock_unlink, mock_tempfile, mock_docx_class, mock_get):
mock_get.return_value = Mock(content=b"fake docx content", raise_for_status=Mock())
mock_temp = Mock(name="/tmp/temp_docx_file.docx")
mock_temp.__enter__ = Mock(return_value=mock_temp)
mock_temp.__exit__ = Mock(return_value=None)
mock_tempfile.return_value = mock_temp
mock_doc = Mock()
mock_doc.paragraphs = [Mock(text="Content from URL")]
mock_doc.tables = []
mock_docx_class.return_value = mock_doc
loader = DOCXLoader()
result = loader.load(SourceContent("https://example.com/test.docx"))
assert "Content from URL" in result.content
assert result.source == "https://example.com/test.docx"
headers = mock_get.call_args[1]['headers']
assert "application/vnd.openxmlformats-officedocument.wordprocessingml.document" in headers['Accept']
assert "crewai-tools DOCXLoader" in headers['User-Agent']
mock_temp.write.assert_called_once_with(b"fake docx content")
@patch('requests.get')
@patch('docx.Document')
def test_load_docx_from_url_with_custom_headers(self, mock_docx_class, mock_get):
mock_get.return_value = Mock(content=b"fake docx content", raise_for_status=Mock())
mock_docx_class.return_value = Mock(paragraphs=[], tables=[])
loader = DOCXLoader()
custom_headers = {"Authorization": "Bearer token"}
with patch('tempfile.NamedTemporaryFile'), patch('os.unlink'):
loader.load(SourceContent("https://example.com/test.docx"), headers=custom_headers)
assert mock_get.call_args[1]['headers'] == custom_headers
@patch('requests.get')
def test_load_docx_url_download_error(self, mock_get):
mock_get.side_effect = Exception("Network error")
loader = DOCXLoader()
with pytest.raises(ValueError, match="Error fetching DOCX from URL"):
loader.load(SourceContent("https://example.com/test.docx"))
@patch('requests.get')
def test_load_docx_url_http_error(self, mock_get):
mock_get.return_value = Mock(raise_for_status=Mock(side_effect=Exception("404 Not Found")))
loader = DOCXLoader()
with pytest.raises(ValueError, match="Error fetching DOCX from URL"):
loader.load(SourceContent("https://example.com/notfound.docx"))
def test_load_docx_invalid_source(self):
loader = DOCXLoader()
with pytest.raises(ValueError, match="Source must be a valid file path or URL"):
loader.load(SourceContent("not_a_file_or_url"))
@patch('docx.Document')
def test_load_docx_parsing_error(self, mock_docx_class):
mock_docx_class.side_effect = Exception("Invalid DOCX file")
with tempfile.NamedTemporaryFile(suffix='.docx') as f:
loader = DOCXLoader()
with pytest.raises(ValueError, match="Error loading DOCX file"):
loader.load(SourceContent(f.name))
@patch('docx.Document')
def test_load_docx_empty_document(self, mock_docx_class):
mock_docx_class.return_value = Mock(paragraphs=[], tables=[])
with tempfile.NamedTemporaryFile(suffix='.docx') as f:
loader = DOCXLoader()
result = loader.load(SourceContent(f.name))
assert result.content == ""
assert result.metadata == {"paragraphs": 0, "tables": 0, "format": "docx"}
@patch('docx.Document')
def test_docx_doc_id_generation(self, mock_docx_class):
mock_docx_class.return_value = Mock(paragraphs=[Mock(text="Consistent content")], tables=[])
with tempfile.NamedTemporaryFile(suffix='.docx') as f:
loader = DOCXLoader()
source = SourceContent(f.name)
assert loader.load(source).doc_id == loader.load(source).doc_id

View File

@@ -0,0 +1,180 @@
import json
import os
import tempfile
import pytest
from unittest.mock import patch, Mock
from crewai_tools.rag.loaders.json_loader import JSONLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent
class TestJSONLoader:
def _create_temp_json_file(self, data) -> str:
"""Helper to write JSON data to a temporary file and return its path."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
json.dump(data, f)
return f.name
def _create_temp_raw_file(self, content: str) -> str:
"""Helper to write raw content to a temporary file and return its path."""
with tempfile.NamedTemporaryFile(mode='w', suffix='.json', delete=False) as f:
f.write(content)
return f.name
def _load_from_path(self, path) -> LoaderResult:
loader = JSONLoader()
return loader.load(SourceContent(path))
def test_load_json_dict(self):
path = self._create_temp_json_file({"name": "John", "age": 30, "items": ["a", "b", "c"]})
try:
result = self._load_from_path(path)
assert isinstance(result, LoaderResult)
assert all(k in result.content for k in ["name", "John", "age", "30"])
assert result.metadata == {
"format": "json", "type": "dict", "size": 3
}
assert result.source == path
finally:
os.unlink(path)
def test_load_json_list(self):
path = self._create_temp_json_file([
{"id": 1, "name": "Item 1"},
{"id": 2, "name": "Item 2"},
])
try:
result = self._load_from_path(path)
assert result.metadata["type"] == "list"
assert result.metadata["size"] == 2
assert all(item in result.content for item in ["Item 1", "Item 2"])
finally:
os.unlink(path)
@pytest.mark.parametrize("value, expected_type", [
("simple string value", "str"),
(42, "int"),
])
def test_load_json_primitives(self, value, expected_type):
path = self._create_temp_json_file(value)
try:
result = self._load_from_path(path)
assert result.metadata["type"] == expected_type
assert result.metadata["size"] == 1
assert str(value) in result.content
finally:
os.unlink(path)
def test_load_malformed_json(self):
path = self._create_temp_raw_file('{"invalid": json,}')
try:
result = self._load_from_path(path)
assert result.metadata["format"] == "json"
assert "parse_error" in result.metadata
assert result.content == '{"invalid": json,}'
finally:
os.unlink(path)
def test_load_empty_file(self):
path = self._create_temp_raw_file('')
try:
result = self._load_from_path(path)
assert "parse_error" in result.metadata
assert result.content == ''
finally:
os.unlink(path)
def test_load_text_input(self):
json_text = '{"message": "hello", "count": 5}'
loader = JSONLoader()
result = loader.load(SourceContent(json_text))
assert all(part in result.content for part in ["message", "hello", "count", "5"])
assert result.metadata["type"] == "dict"
assert result.metadata["size"] == 2
def test_load_complex_nested_json(self):
data = {
"users": [
{"id": 1, "profile": {"name": "Alice", "settings": {"theme": "dark"}}},
{"id": 2, "profile": {"name": "Bob", "settings": {"theme": "light"}}}
],
"meta": {"total": 2, "version": "1.0"}
}
path = self._create_temp_json_file(data)
try:
result = self._load_from_path(path)
for value in ["Alice", "Bob", "dark", "light"]:
assert value in result.content
assert result.metadata["size"] == 2 # top-level keys
finally:
os.unlink(path)
def test_consistent_doc_id(self):
path = self._create_temp_json_file({"test": "data"})
try:
result1 = self._load_from_path(path)
result2 = self._load_from_path(path)
assert result1.doc_id == result2.doc_id
finally:
os.unlink(path)
# ------------------------------
# URL-based tests
# ------------------------------
@patch('requests.get')
def test_url_response_valid_json(self, mock_get):
mock_get.return_value = Mock(
text='{"key": "value", "number": 123}',
json=Mock(return_value={"key": "value", "number": 123}),
raise_for_status=Mock()
)
loader = JSONLoader()
result = loader.load(SourceContent("https://api.example.com/data.json"))
assert all(val in result.content for val in ["key", "value", "number", "123"])
headers = mock_get.call_args[1]['headers']
assert "application/json" in headers['Accept']
assert "crewai-tools JSONLoader" in headers['User-Agent']
@patch('requests.get')
def test_url_response_not_json(self, mock_get):
mock_get.return_value = Mock(
text='{"key": "value"}',
json=Mock(side_effect=ValueError("Not JSON")),
raise_for_status=Mock()
)
loader = JSONLoader()
result = loader.load(SourceContent("https://example.com/data.json"))
assert all(part in result.content for part in ["key", "value"])
@patch('requests.get')
def test_url_with_custom_headers(self, mock_get):
mock_get.return_value = Mock(
text='{"data": "test"}',
json=Mock(return_value={"data": "test"}),
raise_for_status=Mock()
)
headers = {"Authorization": "Bearer token", "Custom-Header": "value"}
loader = JSONLoader()
loader.load(SourceContent("https://api.example.com/data.json"), headers=headers)
assert mock_get.call_args[1]['headers'] == headers
@patch('requests.get')
def test_url_network_failure(self, mock_get):
mock_get.side_effect = Exception("Network error")
loader = JSONLoader()
with pytest.raises(ValueError, match="Error fetching JSON from URL"):
loader.load(SourceContent("https://api.example.com/data.json"))
@patch('requests.get')
def test_url_http_error(self, mock_get):
mock_get.return_value = Mock(raise_for_status=Mock(side_effect=Exception("404")))
loader = JSONLoader()
with pytest.raises(ValueError, match="Error fetching JSON from URL"):
loader.load(SourceContent("https://api.example.com/404.json"))

View File

@@ -0,0 +1,176 @@
import os
import tempfile
import pytest
from unittest.mock import patch, Mock
from crewai_tools.rag.loaders.mdx_loader import MDXLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent
class TestMDXLoader:
def _write_temp_mdx(self, content):
f = tempfile.NamedTemporaryFile(mode='w', suffix='.mdx', delete=False)
f.write(content)
f.close()
return f.name
def _load_from_file(self, content):
path = self._write_temp_mdx(content)
try:
loader = MDXLoader()
return loader.load(SourceContent(path)), path
finally:
os.unlink(path)
def test_load_basic_mdx_file(self):
content = """
import Component from './Component'
export const meta = { title: 'Test' }
# Test MDX File
This is a **markdown** file with JSX.
<Component prop="value" />
Some more content.
<div className="container">
<p>Nested content</p>
</div>
"""
result, path = self._load_from_file(content)
assert isinstance(result, LoaderResult)
assert all(tag not in result.content for tag in ["import", "export", "<Component", "<div", "</div>"])
assert all(text in result.content for text in ["# Test MDX File", "markdown", "Some more content", "Nested content"])
assert result.metadata["format"] == "mdx"
assert result.source == path
def test_mdx_multiple_imports_exports(self):
content = """
import React from 'react'
import { useState } from 'react'
import CustomComponent from './custom'
export default function Layout() { return null }
export const config = { test: true }
# Content
Regular markdown content here.
"""
result, _ = self._load_from_file(content)
assert "# Content" in result.content
assert "Regular markdown content here." in result.content
assert "import" not in result.content and "export" not in result.content
def test_complex_jsx_cleanup(self):
content = """
# MDX with Complex JSX
<div className="alert alert-info">
<strong>Info:</strong> This is important information.
<ul><li>Item 1</li><li>Item 2</li></ul>
</div>
Regular paragraph text.
<MyComponent prop1="value1">Nested content inside component</MyComponent>
"""
result, _ = self._load_from_file(content)
assert all(tag not in result.content for tag in ["<div", "<strong>", "<ul>", "<MyComponent"])
assert all(text in result.content for text in ["Info:", "Item 1", "Regular paragraph text.", "Nested content inside component"])
def test_whitespace_cleanup(self):
content = """
# Title
Some content.
More content after multiple newlines.
Final content.
"""
result, _ = self._load_from_file(content)
assert result.content.count('\n\n\n') == 0
assert result.content.startswith('# Title')
assert result.content.endswith('Final content.')
def test_only_jsx_content(self):
content = """
<div>
<h1>Only JSX content</h1>
<p>No markdown here</p>
</div>
"""
result, _ = self._load_from_file(content)
assert all(tag not in result.content for tag in ["<div>", "<h1>", "<p>"])
assert "Only JSX content" in result.content
assert "No markdown here" in result.content
@patch('requests.get')
def test_load_mdx_from_url(self, mock_get):
mock_get.return_value = Mock(text="# MDX from URL\n\nContent here.\n\n<Component />", raise_for_status=lambda: None)
loader = MDXLoader()
result = loader.load(SourceContent("https://example.com/content.mdx"))
assert "# MDX from URL" in result.content
assert "<Component />" not in result.content
@patch('requests.get')
def test_load_mdx_with_custom_headers(self, mock_get):
mock_get.return_value = Mock(text="# Custom headers test", raise_for_status=lambda: None)
loader = MDXLoader()
loader.load(SourceContent("https://example.com"), headers={"Authorization": "Bearer token"})
assert mock_get.call_args[1]['headers'] == {"Authorization": "Bearer token"}
@patch('requests.get')
def test_mdx_url_fetch_error(self, mock_get):
mock_get.side_effect = Exception("Network error")
with pytest.raises(ValueError, match="Error fetching MDX from URL"):
MDXLoader().load(SourceContent("https://example.com"))
def test_load_inline_mdx_text(self):
content = """# Inline MDX\n\nimport Something from 'somewhere'\n\nContent with <Component prop=\"value\" />.\n\nexport const meta = { title: 'Test' }"""
loader = MDXLoader()
result = loader.load(SourceContent(content))
assert "# Inline MDX" in result.content
assert "Content with ." in result.content
def test_empty_result_after_cleaning(self):
content = """
import Something from 'somewhere'
export const config = {}
<div></div>
"""
result, _ = self._load_from_file(content)
assert result.content.strip() == ""
def test_edge_case_parsing(self):
content = """
# Title
<Component>
Multi-line
JSX content
</Component>
import { a, b } from 'module'
export { x, y }
Final text.
"""
result, _ = self._load_from_file(content)
assert "# Title" in result.content
assert "JSX content" in result.content
assert "Final text." in result.content
assert all(phrase not in result.content for phrase in ["import {", "export {", "<Component>"])

View File

@@ -0,0 +1,160 @@
import hashlib
import os
import tempfile
import pytest
from crewai_tools.rag.loaders.text_loader import TextFileLoader, TextLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent
def write_temp_file(content, suffix=".txt", encoding="utf-8"):
with tempfile.NamedTemporaryFile(mode="w", suffix=suffix, delete=False, encoding=encoding) as f:
f.write(content)
return f.name
def cleanup_temp_file(path):
try:
os.unlink(path)
except FileNotFoundError:
pass
class TestTextFileLoader:
def test_basic_text_file(self):
content = "This is test content\nWith multiple lines\nAnd more text"
path = write_temp_file(content)
try:
result = TextFileLoader().load(SourceContent(path))
assert isinstance(result, LoaderResult)
assert result.content == content
assert result.source == path
assert result.doc_id
assert result.metadata in (None, {})
finally:
cleanup_temp_file(path)
def test_empty_file(self):
path = write_temp_file("")
try:
result = TextFileLoader().load(SourceContent(path))
assert result.content == ""
finally:
cleanup_temp_file(path)
def test_unicode_content(self):
content = "Hello 世界 🌍 émojis 🎉 åäö"
path = write_temp_file(content)
try:
result = TextFileLoader().load(SourceContent(path))
assert content in result.content
finally:
cleanup_temp_file(path)
def test_large_file(self):
content = "\n".join(f"Line {i}" for i in range(100))
path = write_temp_file(content)
try:
result = TextFileLoader().load(SourceContent(path))
assert "Line 0" in result.content
assert "Line 99" in result.content
assert result.content.count("\n") == 99
finally:
cleanup_temp_file(path)
def test_missing_file(self):
with pytest.raises(FileNotFoundError):
TextFileLoader().load(SourceContent("/nonexistent/path.txt"))
def test_permission_denied(self):
path = write_temp_file("Some content")
os.chmod(path, 0o000)
try:
with pytest.raises(PermissionError):
TextFileLoader().load(SourceContent(path))
finally:
os.chmod(path, 0o644)
cleanup_temp_file(path)
def test_doc_id_consistency(self):
content = "Consistent content"
path = write_temp_file(content)
try:
loader = TextFileLoader()
result1 = loader.load(SourceContent(path))
result2 = loader.load(SourceContent(path))
expected_id = hashlib.sha256((path + content).encode("utf-8")).hexdigest()
assert result1.doc_id == result2.doc_id == expected_id
finally:
cleanup_temp_file(path)
def test_various_extensions(self):
content = "Same content"
for ext in [".txt", ".md", ".log", ".json"]:
path = write_temp_file(content, suffix=ext)
try:
result = TextFileLoader().load(SourceContent(path))
assert result.content == content
finally:
cleanup_temp_file(path)
class TestTextLoader:
def test_basic_text(self):
content = "Raw text"
result = TextLoader().load(SourceContent(content))
expected_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()
assert result.content == content
assert result.source == expected_hash
assert result.doc_id == expected_hash
def test_multiline_text(self):
content = "Line 1\nLine 2\nLine 3"
result = TextLoader().load(SourceContent(content))
assert "Line 2" in result.content
def test_empty_text(self):
result = TextLoader().load(SourceContent(""))
assert result.content == ""
assert result.source == hashlib.sha256("".encode("utf-8")).hexdigest()
def test_unicode_text(self):
content = "世界 🌍 émojis 🎉 åäö"
result = TextLoader().load(SourceContent(content))
assert content in result.content
def test_special_characters(self):
content = "!@#$$%^&*()_+-=~`{}[]\\|;:'\",.<>/?"
result = TextLoader().load(SourceContent(content))
assert result.content == content
def test_doc_id_uniqueness(self):
result1 = TextLoader().load(SourceContent("A"))
result2 = TextLoader().load(SourceContent("B"))
assert result1.doc_id != result2.doc_id
def test_whitespace_text(self):
content = " \n\t "
result = TextLoader().load(SourceContent(content))
assert result.content == content
def test_long_text(self):
content = "A" * 10000
result = TextLoader().load(SourceContent(content))
assert len(result.content) == 10000
class TestTextLoadersIntegration:
def test_consistency_between_loaders(self):
content = "Consistent content"
text_result = TextLoader().load(SourceContent(content))
file_path = write_temp_file(content)
try:
file_result = TextFileLoader().load(SourceContent(file_path))
assert text_result.content == file_result.content
assert text_result.source != file_result.source
assert text_result.doc_id != file_result.doc_id
finally:
cleanup_temp_file(file_path)

View File

@@ -0,0 +1,137 @@
import pytest
from unittest.mock import patch, Mock
from crewai_tools.rag.loaders.webpage_loader import WebPageLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent
class TestWebPageLoader:
def setup_mock_response(self, text, status_code=200, content_type="text/html"):
response = Mock()
response.text = text
response.apparent_encoding = "utf-8"
response.status_code = status_code
response.headers = {"content-type": content_type}
return response
def setup_mock_soup(self, text, title=None, script_style_elements=None):
soup = Mock()
soup.get_text.return_value = text
soup.title = Mock(string=title) if title is not None else None
soup.return_value = script_style_elements or []
return soup
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_load_basic_webpage(self, mock_bs, mock_get):
mock_get.return_value = self.setup_mock_response("<html><head><title>Test Page</title></head><body><p>Test content</p></body></html>")
mock_bs.return_value = self.setup_mock_soup("Test content", title="Test Page")
loader = WebPageLoader()
result = loader.load(SourceContent("https://example.com"))
assert isinstance(result, LoaderResult)
assert result.content == "Test content"
assert result.metadata["title"] == "Test Page"
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_load_webpage_with_scripts_and_styles(self, mock_bs, mock_get):
html = """
<html><head><title>Page with Scripts</title><style>body { color: red; }</style></head>
<body><script>console.log('test');</script><p>Visible content</p></body></html>
"""
mock_get.return_value = self.setup_mock_response(html)
scripts = [Mock(), Mock()]
styles = [Mock()]
for el in scripts + styles:
el.decompose = Mock()
mock_bs.return_value = self.setup_mock_soup("Page with Scripts Visible content", title="Page with Scripts", script_style_elements=scripts + styles)
loader = WebPageLoader()
result = loader.load(SourceContent("https://example.com/with-scripts"))
assert "Visible content" in result.content
for el in scripts + styles:
el.decompose.assert_called_once()
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_text_cleaning_and_title_handling(self, mock_bs, mock_get):
mock_get.return_value = self.setup_mock_response("<html><body><p> Messy text </p></body></html>")
mock_bs.return_value = self.setup_mock_soup("Text with extra spaces\n\n More\t text \n\n", title=None)
loader = WebPageLoader()
result = loader.load(SourceContent("https://example.com/messy-text"))
assert result.content is not None
assert result.metadata["title"] == ""
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_empty_or_missing_title(self, mock_bs, mock_get):
for title in [None, ""]:
mock_get.return_value = self.setup_mock_response("<html><head><title></title></head><body>Content</body></html>")
mock_bs.return_value = self.setup_mock_soup("Content", title=title)
loader = WebPageLoader()
result = loader.load(SourceContent("https://example.com"))
assert result.metadata["title"] == ""
@patch('requests.get')
def test_custom_and_default_headers(self, mock_get):
mock_get.return_value = self.setup_mock_response("<html><body>Test</body></html>")
custom_headers = {"User-Agent": "Bot", "Authorization": "Bearer xyz", "Accept": "text/html"}
with patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup') as mock_bs:
mock_bs.return_value = self.setup_mock_soup("Test")
WebPageLoader().load(SourceContent("https://example.com"), headers=custom_headers)
assert mock_get.call_args[1]['headers'] == custom_headers
@patch('requests.get')
def test_error_handling(self, mock_get):
for error in [Exception("Fail"), ValueError("Bad"), ImportError("Oops")]:
mock_get.side_effect = error
with pytest.raises(ValueError, match="Error loading webpage"):
WebPageLoader().load(SourceContent("https://example.com"))
@patch('requests.get')
def test_timeout_and_http_error(self, mock_get):
import requests
mock_get.side_effect = requests.Timeout("Timeout")
with pytest.raises(ValueError):
WebPageLoader().load(SourceContent("https://example.com"))
mock_response = Mock()
mock_response.raise_for_status.side_effect = requests.HTTPError("404")
mock_get.side_effect = None
mock_get.return_value = mock_response
with pytest.raises(ValueError):
WebPageLoader().load(SourceContent("https://example.com/404"))
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_doc_id_consistency(self, mock_bs, mock_get):
mock_get.return_value = self.setup_mock_response("<html><body>Doc</body></html>")
mock_bs.return_value = self.setup_mock_soup("Doc")
loader = WebPageLoader()
result1 = loader.load(SourceContent("https://example.com"))
result2 = loader.load(SourceContent("https://example.com"))
assert result1.doc_id == result2.doc_id
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_status_code_and_content_type(self, mock_bs, mock_get):
for status in [200, 201, 301]:
mock_get.return_value = self.setup_mock_response(f"<html><body>Status {status}</body></html>", status_code=status)
mock_bs.return_value = self.setup_mock_soup(f"Status {status}")
result = WebPageLoader().load(SourceContent(f"https://example.com/{status}"))
assert result.metadata["status_code"] == status
for ctype in ["text/html", "text/plain", "application/xhtml+xml"]:
mock_get.return_value = self.setup_mock_response("<html><body>Content</body></html>", content_type=ctype)
mock_bs.return_value = self.setup_mock_soup("Content")
result = WebPageLoader().load(SourceContent("https://example.com"))
assert result.metadata["content_type"] == ctype

View File

@@ -0,0 +1,137 @@
import pytest
from unittest.mock import patch, Mock
from crewai_tools.rag.loaders.webpage_loader import WebPageLoader
from crewai_tools.rag.base_loader import LoaderResult
from crewai_tools.rag.source_content import SourceContent
class TestWebPageLoader:
def setup_mock_response(self, text, status_code=200, content_type="text/html"):
response = Mock()
response.text = text
response.apparent_encoding = "utf-8"
response.status_code = status_code
response.headers = {"content-type": content_type}
return response
def setup_mock_soup(self, text, title=None, script_style_elements=None):
soup = Mock()
soup.get_text.return_value = text
soup.title = Mock(string=title) if title is not None else None
soup.return_value = script_style_elements or []
return soup
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_load_basic_webpage(self, mock_bs, mock_get):
mock_get.return_value = self.setup_mock_response("<html><head><title>Test Page</title></head><body><p>Test content</p></body></html>")
mock_bs.return_value = self.setup_mock_soup("Test content", title="Test Page")
loader = WebPageLoader()
result = loader.load(SourceContent("https://example.com"))
assert isinstance(result, LoaderResult)
assert result.content == "Test content"
assert result.metadata["title"] == "Test Page"
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_load_webpage_with_scripts_and_styles(self, mock_bs, mock_get):
html = """
<html><head><title>Page with Scripts</title><style>body { color: red; }</style></head>
<body><script>console.log('test');</script><p>Visible content</p></body></html>
"""
mock_get.return_value = self.setup_mock_response(html)
scripts = [Mock(), Mock()]
styles = [Mock()]
for el in scripts + styles:
el.decompose = Mock()
mock_bs.return_value = self.setup_mock_soup("Page with Scripts Visible content", title="Page with Scripts", script_style_elements=scripts + styles)
loader = WebPageLoader()
result = loader.load(SourceContent("https://example.com/with-scripts"))
assert "Visible content" in result.content
for el in scripts + styles:
el.decompose.assert_called_once()
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_text_cleaning_and_title_handling(self, mock_bs, mock_get):
mock_get.return_value = self.setup_mock_response("<html><body><p> Messy text </p></body></html>")
mock_bs.return_value = self.setup_mock_soup("Text with extra spaces\n\n More\t text \n\n", title=None)
loader = WebPageLoader()
result = loader.load(SourceContent("https://example.com/messy-text"))
assert result.content is not None
assert result.metadata["title"] == ""
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_empty_or_missing_title(self, mock_bs, mock_get):
for title in [None, ""]:
mock_get.return_value = self.setup_mock_response("<html><head><title></title></head><body>Content</body></html>")
mock_bs.return_value = self.setup_mock_soup("Content", title=title)
loader = WebPageLoader()
result = loader.load(SourceContent("https://example.com"))
assert result.metadata["title"] == ""
@patch('requests.get')
def test_custom_and_default_headers(self, mock_get):
mock_get.return_value = self.setup_mock_response("<html><body>Test</body></html>")
custom_headers = {"User-Agent": "Bot", "Authorization": "Bearer xyz", "Accept": "text/html"}
with patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup') as mock_bs:
mock_bs.return_value = self.setup_mock_soup("Test")
WebPageLoader().load(SourceContent("https://example.com"), headers=custom_headers)
assert mock_get.call_args[1]['headers'] == custom_headers
@patch('requests.get')
def test_error_handling(self, mock_get):
for error in [Exception("Fail"), ValueError("Bad"), ImportError("Oops")]:
mock_get.side_effect = error
with pytest.raises(ValueError, match="Error loading webpage"):
WebPageLoader().load(SourceContent("https://example.com"))
@patch('requests.get')
def test_timeout_and_http_error(self, mock_get):
import requests
mock_get.side_effect = requests.Timeout("Timeout")
with pytest.raises(ValueError):
WebPageLoader().load(SourceContent("https://example.com"))
mock_response = Mock()
mock_response.raise_for_status.side_effect = requests.HTTPError("404")
mock_get.side_effect = None
mock_get.return_value = mock_response
with pytest.raises(ValueError):
WebPageLoader().load(SourceContent("https://example.com/404"))
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_doc_id_consistency(self, mock_bs, mock_get):
mock_get.return_value = self.setup_mock_response("<html><body>Doc</body></html>")
mock_bs.return_value = self.setup_mock_soup("Doc")
loader = WebPageLoader()
result1 = loader.load(SourceContent("https://example.com"))
result2 = loader.load(SourceContent("https://example.com"))
assert result1.doc_id == result2.doc_id
@patch('requests.get')
@patch('crewai_tools.rag.loaders.webpage_loader.BeautifulSoup')
def test_status_code_and_content_type(self, mock_bs, mock_get):
for status in [200, 201, 301]:
mock_get.return_value = self.setup_mock_response(f"<html><body>Status {status}</body></html>", status_code=status)
mock_bs.return_value = self.setup_mock_soup(f"Status {status}")
result = WebPageLoader().load(SourceContent(f"https://example.com/{status}"))
assert result.metadata["status_code"] == status
for ctype in ["text/html", "text/plain", "application/xhtml+xml"]:
mock_get.return_value = self.setup_mock_response("<html><body>Content</body></html>", content_type=ctype)
mock_bs.return_value = self.setup_mock_soup("Content")
result = WebPageLoader().load(SourceContent("https://example.com"))
assert result.metadata["content_type"] == ctype

View File

@@ -0,0 +1,193 @@
import json
from typing import List, Optional, Type
from unittest import mock
import pytest
from crewai.tools.base_tool import BaseTool, EnvVar
from pydantic import BaseModel, Field
from generate_tool_specs import ToolSpecExtractor
class MockToolSchema(BaseModel):
query: str = Field(..., description="The query parameter")
count: int = Field(5, description="Number of results to return")
filters: Optional[List[str]] = Field(None, description="Optional filters to apply")
class MockTool(BaseTool):
name: str = "Mock Search Tool"
description: str = "A tool that mocks search functionality"
args_schema: Type[BaseModel] = MockToolSchema
another_parameter: str = Field(
"Another way to define a default value", description=""
)
my_parameter: str = Field("This is default value", description="What a description")
my_parameter_bool: bool = Field(False)
package_dependencies: List[str] = Field(
["this-is-a-required-package", "another-required-package"], description=""
)
env_vars: List[EnvVar] = [
EnvVar(
name="SERPER_API_KEY",
description="API key for Serper",
required=True,
default=None,
),
EnvVar(
name="API_RATE_LIMIT",
description="API rate limit",
required=False,
default="100",
),
]
@pytest.fixture
def extractor():
ext = ToolSpecExtractor()
return ext
def test_unwrap_schema(extractor):
nested_schema = {
"type": "function-after",
"schema": {"type": "default", "schema": {"type": "str", "value": "test"}},
}
result = extractor._unwrap_schema(nested_schema)
assert result["type"] == "str"
assert result["value"] == "test"
@pytest.fixture
def mock_tool_extractor(extractor):
with (
mock.patch("generate_tool_specs.dir", return_value=["MockTool"]),
mock.patch("generate_tool_specs.getattr", return_value=MockTool),
):
extractor.extract_all_tools()
assert len(extractor.tools_spec) == 1
return extractor.tools_spec[0]
def test_extract_basic_tool_info(mock_tool_extractor):
tool_info = mock_tool_extractor
assert tool_info.keys() == {
"name",
"humanized_name",
"description",
"run_params_schema",
"env_vars",
"init_params_schema",
"package_dependencies",
}
assert tool_info["name"] == "MockTool"
assert tool_info["humanized_name"] == "Mock Search Tool"
assert tool_info["description"] == "A tool that mocks search functionality"
def test_extract_init_params_schema(mock_tool_extractor):
tool_info = mock_tool_extractor
init_params_schema = tool_info["init_params_schema"]
assert init_params_schema.keys() == {
"$defs",
"properties",
"title",
"type",
}
another_parameter = init_params_schema["properties"]["another_parameter"]
assert another_parameter["description"] == ""
assert another_parameter["default"] == "Another way to define a default value"
assert another_parameter["type"] == "string"
my_parameter = init_params_schema["properties"]["my_parameter"]
assert my_parameter["description"] == "What a description"
assert my_parameter["default"] == "This is default value"
assert my_parameter["type"] == "string"
my_parameter_bool = init_params_schema["properties"]["my_parameter_bool"]
assert my_parameter_bool["default"] == False
assert my_parameter_bool["type"] == "boolean"
def test_extract_env_vars(mock_tool_extractor):
tool_info = mock_tool_extractor
assert len(tool_info["env_vars"]) == 2
api_key_var, rate_limit_var = tool_info["env_vars"]
assert api_key_var["name"] == "SERPER_API_KEY"
assert api_key_var["description"] == "API key for Serper"
assert api_key_var["required"] == True
assert api_key_var["default"] == None
assert rate_limit_var["name"] == "API_RATE_LIMIT"
assert rate_limit_var["description"] == "API rate limit"
assert rate_limit_var["required"] == False
assert rate_limit_var["default"] == "100"
def test_extract_run_params_schema(mock_tool_extractor):
tool_info = mock_tool_extractor
run_params_schema = tool_info["run_params_schema"]
assert run_params_schema.keys() == {
"properties",
"required",
"title",
"type",
}
query_param = run_params_schema["properties"]["query"]
assert query_param["description"] == "The query parameter"
assert query_param["type"] == "string"
count_param = run_params_schema["properties"]["count"]
assert count_param["type"] == "integer"
assert count_param["default"] == 5
filters_param = run_params_schema["properties"]["filters"]
assert filters_param["description"] == "Optional filters to apply"
assert filters_param["default"] == None
assert filters_param["anyOf"] == [
{"items": {"type": "string"}, "type": "array"},
{"type": "null"},
]
def test_extract_package_dependencies(mock_tool_extractor):
tool_info = mock_tool_extractor
assert tool_info["package_dependencies"] == [
"this-is-a-required-package",
"another-required-package",
]
def test_save_to_json(extractor, tmp_path):
extractor.tools_spec = [
{
"name": "TestTool",
"humanized_name": "Test Tool",
"description": "A test tool",
"run_params_schema": [
{"name": "param1", "description": "Test parameter", "type": "str"}
],
}
]
file_path = tmp_path / "output.json"
extractor.save_to_json(str(file_path))
assert file_path.exists()
with open(file_path, "r") as f:
data = json.load(f)
assert "tools" in data
assert len(data["tools"]) == 1
assert data["tools"][0]["humanized_name"] == "Test Tool"
assert data["tools"][0]["run_params_schema"][0]["name"] == "param1"

View File

@@ -0,0 +1,41 @@
import subprocess
import tempfile
from pathlib import Path
import pytest
@pytest.fixture
def temp_project():
temp_dir = tempfile.TemporaryDirectory()
project_dir = Path(temp_dir.name) / "test_project"
project_dir.mkdir()
pyproject_content = f"""
[project]
name = "test-project"
version = "0.1.0"
description = "Test project"
requires-python = ">=3.10"
"""
(project_dir / "pyproject.toml").write_text(pyproject_content)
run_command(["uv", "add", "--editable", f"file://{Path.cwd().absolute()}"], project_dir)
run_command(["uv", "sync"], project_dir)
yield project_dir
def run_command(cmd, cwd):
return subprocess.run(cmd, cwd=cwd, capture_output=True, text=True)
def test_no_optional_dependencies_in_init(temp_project):
"""
Test that crewai-tools can be imported without optional dependencies.
The package defines optional dependencies in pyproject.toml, but the base
package should be importable without any of these optional dependencies
being installed.
"""
result = run_command(["uv", "run", "python", "-c", "import crewai_tools"], temp_project)
assert result.returncode == 0, f"Import failed with error: {result.stderr}"

0
tests/tools/__init__.py Normal file
View File

View File

@@ -0,0 +1,50 @@
from unittest.mock import patch
import pytest
from crewai_tools.tools.brave_search_tool.brave_search_tool import BraveSearchTool
@pytest.fixture
def brave_tool():
return BraveSearchTool(n_results=2)
def test_brave_tool_initialization():
tool = BraveSearchTool()
assert tool.n_results == 10
assert tool.save_file is False
@patch("requests.get")
def test_brave_tool_search(mock_get, brave_tool):
mock_response = {
"web": {
"results": [
{
"title": "Test Title",
"url": "http://test.com",
"description": "Test Description",
}
]
}
}
mock_get.return_value.json.return_value = mock_response
result = brave_tool.run(search_query="test")
assert "Test Title" in result
assert "http://test.com" in result
def test_brave_tool():
tool = BraveSearchTool(
n_results=2,
)
x = tool.run(search_query="ChatGPT")
print(x)
if __name__ == "__main__":
test_brave_tool()
test_brave_tool_initialization()
# test_brave_tool_search(brave_tool)

View File

@@ -0,0 +1,54 @@
import unittest
from unittest.mock import MagicMock, patch
from crewai_tools.tools.brightdata_tool.brightdata_serp import BrightDataSearchTool
class TestBrightDataSearchTool(unittest.TestCase):
@patch.dict(
"os.environ",
{"BRIGHT_DATA_API_KEY": "test_api_key", "BRIGHT_DATA_ZONE": "test_zone"},
)
def setUp(self):
self.tool = BrightDataSearchTool()
@patch("requests.post")
def test_run_successful_search(self, mock_post):
# Sample mock JSON response
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.text = "mock response text"
mock_post.return_value = mock_response
# Define search input
input_data = {
"query": "latest AI news",
"search_engine": "google",
"country": "us",
"language": "en",
"search_type": "nws",
"device_type": "desktop",
"parse_results": True,
"save_file": False,
}
result = self.tool._run(**input_data)
# Assertions
self.assertIsInstance(result, str) # Your tool returns response.text (string)
mock_post.assert_called_once()
@patch("requests.post")
def test_run_with_request_exception(self, mock_post):
mock_post.side_effect = Exception("Timeout")
result = self.tool._run(query="AI", search_engine="google")
self.assertIn("Error", result)
def tearDown(self):
# Clean up env vars
pass
if __name__ == "__main__":
unittest.main()

View File

@@ -0,0 +1,64 @@
from unittest.mock import Mock, patch
import requests
from crewai_tools.tools.brightdata_tool.brightdata_unlocker import (
BrightDataWebUnlockerTool,
)
@patch.dict(
"os.environ",
{"BRIGHT_DATA_API_KEY": "test_api_key", "BRIGHT_DATA_ZONE": "test_zone"},
)
@patch("crewai_tools.tools.brightdata_tool.brightdata_unlocker.requests.post")
def test_run_success_html(mock_post):
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = "<html><body>Test</body></html>"
mock_response.raise_for_status = Mock()
mock_post.return_value = mock_response
tool = BrightDataWebUnlockerTool()
result = tool._run(url="https://example.com", format="html", save_file=False)
print(result)
@patch.dict(
"os.environ",
{"BRIGHT_DATA_API_KEY": "test_api_key", "BRIGHT_DATA_ZONE": "test_zone"},
)
@patch("crewai_tools.tools.brightdata_tool.brightdata_unlocker.requests.post")
def test_run_success_json(mock_post):
mock_response = Mock()
mock_response.status_code = 200
mock_response.text = "mock response text"
mock_response.raise_for_status = Mock()
mock_post.return_value = mock_response
tool = BrightDataWebUnlockerTool()
result = tool._run(url="https://example.com", format="json")
assert isinstance(result, str)
@patch.dict(
"os.environ",
{"BRIGHT_DATA_API_KEY": "test_api_key", "BRIGHT_DATA_ZONE": "test_zone"},
)
@patch("crewai_tools.tools.brightdata_tool.brightdata_unlocker.requests.post")
def test_run_http_error(mock_post):
mock_response = Mock()
mock_response.status_code = 403
mock_response.text = "Forbidden"
mock_response.raise_for_status.side_effect = requests.HTTPError(
response=mock_response
)
mock_post.return_value = mock_response
tool = BrightDataWebUnlockerTool()
result = tool._run(url="https://example.com")
assert "HTTP Error" in result
assert "Forbidden" in result

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,365 @@
import pytest
from unittest.mock import MagicMock, patch, ANY
# Mock the couchbase library before importing the tool
# This prevents ImportErrors if couchbase isn't installed in the test environment
mock_couchbase = MagicMock()
mock_couchbase.search = MagicMock()
mock_couchbase.cluster = MagicMock()
mock_couchbase.options = MagicMock()
mock_couchbase.vector_search = MagicMock()
# Simulate the structure needed for checks
mock_couchbase.cluster.Cluster = MagicMock()
mock_couchbase.options.SearchOptions = MagicMock()
mock_couchbase.vector_search.VectorQuery = MagicMock()
mock_couchbase.vector_search.VectorSearch = MagicMock()
mock_couchbase.search.SearchRequest = MagicMock() # Mock the class itself
mock_couchbase.search.SearchRequest.create = MagicMock() # Mock the class method
# Add necessary exception types if needed for testing error handling
class MockCouchbaseException(Exception):
pass
mock_couchbase.exceptions = MagicMock()
mock_couchbase.exceptions.BucketNotFoundException = MockCouchbaseException
mock_couchbase.exceptions.ScopeNotFoundException = MockCouchbaseException
mock_couchbase.exceptions.CollectionNotFoundException = MockCouchbaseException
mock_couchbase.exceptions.IndexNotFoundException = MockCouchbaseException
import sys
sys.modules['couchbase'] = mock_couchbase
sys.modules['couchbase.search'] = mock_couchbase.search
sys.modules['couchbase.cluster'] = mock_couchbase.cluster
sys.modules['couchbase.options'] = mock_couchbase.options
sys.modules['couchbase.vector_search'] = mock_couchbase.vector_search
sys.modules['couchbase.exceptions'] = mock_couchbase.exceptions
# Now import the tool
from crewai_tools.tools.couchbase_tool.couchbase_tool import CouchbaseFTSVectorSearchTool
# --- Test Fixtures ---
@pytest.fixture(autouse=True)
def reset_global_mocks():
"""Reset call counts for globally defined mocks before each test."""
# Reset the specific mock causing the issue
mock_couchbase.vector_search.VectorQuery.reset_mock()
# It's good practice to also reset other related global mocks
# that might be called in your tests to prevent similar issues:
mock_couchbase.vector_search.VectorSearch.from_vector_query.reset_mock()
mock_couchbase.search.SearchRequest.create.reset_mock()
# Additional fixture to handle import pollution in full test suite
@pytest.fixture(autouse=True)
def ensure_couchbase_mocks():
"""Ensure that couchbase imports are properly mocked even when other tests have run first."""
# This fixture ensures our mocks are in place regardless of import order
original_modules = {}
# Store any existing modules
for module_name in ['couchbase', 'couchbase.search', 'couchbase.cluster', 'couchbase.options', 'couchbase.vector_search', 'couchbase.exceptions']:
if module_name in sys.modules:
original_modules[module_name] = sys.modules[module_name]
# Ensure our mocks are active
sys.modules['couchbase'] = mock_couchbase
sys.modules['couchbase.search'] = mock_couchbase.search
sys.modules['couchbase.cluster'] = mock_couchbase.cluster
sys.modules['couchbase.options'] = mock_couchbase.options
sys.modules['couchbase.vector_search'] = mock_couchbase.vector_search
sys.modules['couchbase.exceptions'] = mock_couchbase.exceptions
yield
# Restore original modules if they existed
for module_name, original_module in original_modules.items():
if original_module is not None:
sys.modules[module_name] = original_module
@pytest.fixture
def mock_cluster():
cluster = MagicMock()
bucket_manager = MagicMock()
search_index_manager = MagicMock()
bucket = MagicMock()
scope = MagicMock()
collection = MagicMock()
scope_search_index_manager = MagicMock()
# Setup mock return values for checks
cluster.buckets.return_value = bucket_manager
cluster.search_indexes.return_value = search_index_manager
cluster.bucket.return_value = bucket
bucket.scope.return_value = scope
scope.collection.return_value = collection
scope.search_indexes.return_value = scope_search_index_manager
# Mock bucket existence check
bucket_manager.get_bucket.return_value = True
# Mock scope/collection existence check
mock_scope_spec = MagicMock()
mock_scope_spec.name = "test_scope"
mock_collection_spec = MagicMock()
mock_collection_spec.name = "test_collection"
mock_scope_spec.collections = [mock_collection_spec]
bucket.collections.return_value.get_all_scopes.return_value = [mock_scope_spec]
# Mock index existence check
mock_index_def = MagicMock()
mock_index_def.name = "test_index"
scope_search_index_manager.get_all_indexes.return_value = [mock_index_def]
search_index_manager.get_all_indexes.return_value = [mock_index_def]
return cluster
@pytest.fixture
def mock_embedding_function():
# Simple mock embedding function
# return lambda query: [0.1] * 10 # Example embedding vector
return MagicMock(return_value=[0.1] * 10)
@pytest.fixture
def tool_config(mock_cluster, mock_embedding_function):
return {
"cluster": mock_cluster,
"bucket_name": "test_bucket",
"scope_name": "test_scope",
"collection_name": "test_collection",
"index_name": "test_index",
"embedding_function": mock_embedding_function,
"limit": 5,
"embedding_key": "test_embedding",
"scoped_index": True
}
@pytest.fixture
def couchbase_tool(tool_config):
# Patch COUCHBASE_AVAILABLE to True for these tests
with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', True):
tool = CouchbaseFTSVectorSearchTool(**tool_config)
return tool
@pytest.fixture
def mock_search_iter():
mock_iter = MagicMock()
# Simulate search results with a 'fields' attribute
mock_row1 = MagicMock()
mock_row1.fields = {"id": "doc1", "text": "content 1", "test_embedding": [0.1]*10}
mock_row2 = MagicMock()
mock_row2.fields = {"id": "doc2", "text": "content 2", "test_embedding": [0.2]*10}
mock_iter.rows.return_value = [mock_row1, mock_row2]
return mock_iter
# --- Test Cases ---
def test_initialization_success(couchbase_tool, tool_config):
"""Test successful initialization with valid config."""
assert couchbase_tool.cluster == tool_config["cluster"]
assert couchbase_tool.bucket_name == "test_bucket"
assert couchbase_tool.scope_name == "test_scope"
assert couchbase_tool.collection_name == "test_collection"
assert couchbase_tool.index_name == "test_index"
assert couchbase_tool.embedding_function is not None
assert couchbase_tool.limit == 5
assert couchbase_tool.embedding_key == "test_embedding"
assert couchbase_tool.scoped_index == True
# Check if helper methods were called during init (via mocks in fixture)
couchbase_tool.cluster.buckets().get_bucket.assert_called_once_with("test_bucket")
couchbase_tool.cluster.bucket().collections().get_all_scopes.assert_called_once()
couchbase_tool.cluster.bucket().scope().search_indexes().get_all_indexes.assert_called_once()
def test_initialization_missing_required_args(mock_cluster, mock_embedding_function):
"""Test initialization fails when required arguments are missing."""
base_config = {
"cluster": mock_cluster, "bucket_name": "b", "scope_name": "s",
"collection_name": "c", "index_name": "i", "embedding_function": mock_embedding_function
}
required_keys = base_config.keys()
with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', True):
for key in required_keys:
incomplete_config = base_config.copy()
del incomplete_config[key]
with pytest.raises(ValueError):
CouchbaseFTSVectorSearchTool(**incomplete_config)
def test_initialization_couchbase_unavailable():
"""Test behavior when couchbase library is not available."""
with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', False):
with patch('click.confirm', return_value=False) as mock_confirm:
with pytest.raises(ImportError, match="The 'couchbase' package is required"):
CouchbaseFTSVectorSearchTool(cluster=MagicMock(), bucket_name="b", scope_name="s",
collection_name="c", index_name="i", embedding_function=MagicMock())
mock_confirm.assert_called_once() # Ensure user was prompted
def test_run_success_scoped_index(couchbase_tool, mock_search_iter, tool_config, mock_embedding_function):
"""Test successful _run execution with a scoped index."""
query = "find relevant documents"
# expected_embedding = mock_embedding_function(query)
# Mock the scope search method
couchbase_tool._scope.search = MagicMock(return_value=mock_search_iter)
# Mock the VectorQuery/VectorSearch/SearchRequest creation using runtime patching
with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.VectorQuery') as mock_vq, \
patch('crewai_tools.tools.couchbase_tool.couchbase_tool.VectorSearch') as mock_vs, \
patch('crewai_tools.tools.couchbase_tool.couchbase_tool.search.SearchRequest') as mock_sr, \
patch('crewai_tools.tools.couchbase_tool.couchbase_tool.SearchOptions') as mock_so:
# Set up the mock objects and their return values
mock_vector_query = MagicMock()
mock_vector_search = MagicMock()
mock_search_req = MagicMock()
mock_search_options = MagicMock()
mock_vq.return_value = mock_vector_query
mock_vs.from_vector_query.return_value = mock_vector_search
mock_sr.create.return_value = mock_search_req
mock_so.return_value = mock_search_options
result = couchbase_tool._run(query=query)
# Check embedding function call
tool_config['embedding_function'].assert_called_once_with(query)
# Check VectorQuery call
mock_vq.assert_called_once_with(
tool_config['embedding_key'], mock_embedding_function.return_value, tool_config['limit']
)
# Check VectorSearch call
mock_vs.from_vector_query.assert_called_once_with(mock_vector_query)
# Check SearchRequest creation
mock_sr.create.assert_called_once_with(mock_vector_search)
# Check SearchOptions creation
mock_so.assert_called_once_with(limit=tool_config['limit'], fields=["*"])
# Check that scope search was called correctly
couchbase_tool._scope.search.assert_called_once_with(
tool_config['index_name'],
mock_search_req,
mock_search_options
)
# Check cluster search was NOT called
couchbase_tool.cluster.search.assert_not_called()
# Check result format (simple check for JSON structure)
assert '"id": "doc1"' in result
assert '"id": "doc2"' in result
assert result.startswith('[') # Should be valid JSON after concatenation
def test_run_success_global_index(tool_config, mock_search_iter, mock_embedding_function):
"""Test successful _run execution with a global (non-scoped) index."""
tool_config['scoped_index'] = False
with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', True):
couchbase_tool = CouchbaseFTSVectorSearchTool(**tool_config)
query = "find global documents"
# expected_embedding = mock_embedding_function(query)
# Mock the cluster search method
couchbase_tool.cluster.search = MagicMock(return_value=mock_search_iter)
# Mock the VectorQuery/VectorSearch/SearchRequest creation using runtime patching
with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.VectorQuery') as mock_vq, \
patch('crewai_tools.tools.couchbase_tool.couchbase_tool.VectorSearch') as mock_vs, \
patch('crewai_tools.tools.couchbase_tool.couchbase_tool.search.SearchRequest') as mock_sr, \
patch('crewai_tools.tools.couchbase_tool.couchbase_tool.SearchOptions') as mock_so:
# Set up the mock objects and their return values
mock_vector_query = MagicMock()
mock_vector_search = MagicMock()
mock_search_req = MagicMock()
mock_search_options = MagicMock()
mock_vq.return_value = mock_vector_query
mock_vs.from_vector_query.return_value = mock_vector_search
mock_sr.create.return_value = mock_search_req
mock_so.return_value = mock_search_options
result = couchbase_tool._run(query=query)
# Check embedding function call
tool_config['embedding_function'].assert_called_once_with(query)
# Check VectorQuery/Search call
mock_vq.assert_called_once_with(
tool_config['embedding_key'], mock_embedding_function.return_value, tool_config['limit']
)
mock_sr.create.assert_called_once_with(mock_vector_search)
# Check SearchOptions creation
mock_so.assert_called_once_with(limit=tool_config['limit'], fields=["*"])
# Check that cluster search was called correctly
couchbase_tool.cluster.search.assert_called_once_with(
tool_config['index_name'],
mock_search_req,
mock_search_options
)
# Check scope search was NOT called
couchbase_tool._scope.search.assert_not_called()
# Check result format
assert '"id": "doc1"' in result
assert '"id": "doc2"' in result
def test_check_bucket_exists_fail(tool_config):
"""Test check for bucket non-existence."""
mock_cluster = tool_config['cluster']
mock_cluster.buckets().get_bucket.side_effect = mock_couchbase.exceptions.BucketNotFoundException("Bucket not found")
with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', True):
with pytest.raises(ValueError, match="Bucket test_bucket does not exist."):
CouchbaseFTSVectorSearchTool(**tool_config)
def test_check_scope_exists_fail(tool_config):
"""Test check for scope non-existence."""
mock_cluster = tool_config['cluster']
# Simulate scope not being in the list returned
mock_scope_spec = MagicMock()
mock_scope_spec.name = "wrong_scope"
mock_cluster.bucket().collections().get_all_scopes.return_value = [mock_scope_spec]
with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', True):
with pytest.raises(ValueError, match="Scope test_scope not found"):
CouchbaseFTSVectorSearchTool(**tool_config)
def test_check_collection_exists_fail(tool_config):
"""Test check for collection non-existence."""
mock_cluster = tool_config['cluster']
# Simulate collection not being in the scope's list
mock_scope_spec = MagicMock()
mock_scope_spec.name = "test_scope"
mock_collection_spec = MagicMock()
mock_collection_spec.name = "wrong_collection"
mock_scope_spec.collections = [mock_collection_spec] # Only has wrong collection
mock_cluster.bucket().collections().get_all_scopes.return_value = [mock_scope_spec]
with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', True):
with pytest.raises(ValueError, match="Collection test_collection not found"):
CouchbaseFTSVectorSearchTool(**tool_config)
def test_check_index_exists_fail_scoped(tool_config):
"""Test check for scoped index non-existence."""
mock_cluster = tool_config['cluster']
# Simulate index not being in the list returned by scope manager
mock_cluster.bucket().scope().search_indexes().get_all_indexes.return_value = []
with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', True):
with pytest.raises(ValueError, match="Index test_index does not exist"):
CouchbaseFTSVectorSearchTool(**tool_config)
def test_check_index_exists_fail_global(tool_config):
"""Test check for global index non-existence."""
tool_config['scoped_index'] = False
mock_cluster = tool_config['cluster']
# Simulate index not being in the list returned by cluster manager
mock_cluster.search_indexes().get_all_indexes.return_value = []
with patch('crewai_tools.tools.couchbase_tool.couchbase_tool.COUCHBASE_AVAILABLE', True):
with pytest.raises(ValueError, match="Index test_index does not exist"):
CouchbaseFTSVectorSearchTool(**tool_config)

View File

@@ -0,0 +1,355 @@
import os
import unittest
from unittest.mock import patch, MagicMock
from crewai.tools import BaseTool
from crewai_tools.tools import CrewaiEnterpriseTools
from crewai_tools.adapters.tool_collection import ToolCollection
from crewai_tools.adapters.enterprise_adapter import EnterpriseActionTool
class TestCrewaiEnterpriseTools(unittest.TestCase):
def setUp(self):
self.mock_tools = [
self._create_mock_tool("tool1", "Tool 1 Description"),
self._create_mock_tool("tool2", "Tool 2 Description"),
self._create_mock_tool("tool3", "Tool 3 Description"),
]
self.adapter_patcher = patch(
"crewai_tools.tools.crewai_enterprise_tools.crewai_enterprise_tools.EnterpriseActionKitToolAdapter"
)
self.MockAdapter = self.adapter_patcher.start()
mock_adapter_instance = self.MockAdapter.return_value
mock_adapter_instance.tools.return_value = self.mock_tools
def tearDown(self):
self.adapter_patcher.stop()
def _create_mock_tool(self, name, description):
mock_tool = MagicMock(spec=BaseTool)
mock_tool.name = name
mock_tool.description = description
return mock_tool
@patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"})
def test_returns_tool_collection(self):
tools = CrewaiEnterpriseTools()
self.assertIsInstance(tools, ToolCollection)
@patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"})
def test_returns_all_tools_when_no_actions_list(self):
tools = CrewaiEnterpriseTools()
self.assertEqual(len(tools), 3)
self.assertEqual(tools[0].name, "tool1")
self.assertEqual(tools[1].name, "tool2")
self.assertEqual(tools[2].name, "tool3")
@patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"})
def test_filters_tools_by_actions_list(self):
tools = CrewaiEnterpriseTools(actions_list=["ToOl1", "tool3"])
self.assertEqual(len(tools), 2)
self.assertEqual(tools[0].name, "tool1")
self.assertEqual(tools[1].name, "tool3")
def test_uses_provided_parameters(self):
CrewaiEnterpriseTools(
enterprise_token="test-token",
enterprise_action_kit_project_id="project-id",
enterprise_action_kit_project_url="project-url",
)
self.MockAdapter.assert_called_once_with(
enterprise_action_token="test-token",
enterprise_action_kit_project_id="project-id",
enterprise_action_kit_project_url="project-url",
)
@patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"})
def test_uses_environment_token(self):
CrewaiEnterpriseTools()
self.MockAdapter.assert_called_once_with(enterprise_action_token="env-token")
@patch.dict(os.environ, {"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token"})
def test_uses_environment_token_when_no_token_provided(self):
CrewaiEnterpriseTools(enterprise_token="")
self.MockAdapter.assert_called_once_with(enterprise_action_token="env-token")
@patch.dict(
os.environ,
{
"CREWAI_ENTERPRISE_TOOLS_TOKEN": "env-token",
"CREWAI_ENTERPRISE_TOOLS_ACTIONS_LIST": '["tool1", "tool3"]',
},
)
def test_uses_environment_actions_list(self):
tools = CrewaiEnterpriseTools()
self.assertEqual(len(tools), 2)
self.assertEqual(tools[0].name, "tool1")
self.assertEqual(tools[1].name, "tool3")
class TestEnterpriseActionToolSchemaConversion(unittest.TestCase):
"""Test the enterprise action tool schema conversion and validation."""
def setUp(self):
self.test_schema = {
"type": "function",
"function": {
"name": "TEST_COMPLEX_ACTION",
"description": "Test action with complex nested structure",
"parameters": {
"type": "object",
"properties": {
"filterCriteria": {
"type": "object",
"description": "Filter criteria object",
"properties": {
"operation": {"type": "string", "enum": ["AND", "OR"]},
"rules": {
"type": "array",
"items": {
"type": "object",
"properties": {
"field": {
"type": "string",
"enum": ["name", "email", "status"],
},
"operator": {
"type": "string",
"enum": ["equals", "contains"],
},
"value": {"type": "string"},
},
"required": ["field", "operator", "value"],
},
},
},
"required": ["operation", "rules"],
},
"options": {
"type": "object",
"properties": {
"limit": {"type": "integer"},
"offset": {"type": "integer"},
},
"required": [],
},
},
"required": [],
},
},
}
def test_complex_schema_conversion(self):
"""Test that complex nested schemas are properly converted to Pydantic models."""
tool = EnterpriseActionTool(
name="gmail_search_for_email",
description="Test tool",
enterprise_action_token="test_token",
action_name="GMAIL_SEARCH_FOR_EMAIL",
action_schema=self.test_schema,
)
self.assertEqual(tool.name, "gmail_search_for_email")
self.assertEqual(tool.action_name, "GMAIL_SEARCH_FOR_EMAIL")
schema_class = tool.args_schema
self.assertIsNotNone(schema_class)
schema_fields = schema_class.model_fields
self.assertIn("filterCriteria", schema_fields)
self.assertIn("options", schema_fields)
# Test valid input structure
valid_input = {
"filterCriteria": {
"operation": "AND",
"rules": [
{"field": "name", "operator": "contains", "value": "test"},
{"field": "status", "operator": "equals", "value": "active"},
],
},
"options": {"limit": 10},
}
# This should not raise an exception
validated_input = schema_class(**valid_input)
self.assertIsNotNone(validated_input.filterCriteria)
self.assertIsNotNone(validated_input.options)
def test_optional_fields_validation(self):
"""Test that optional fields work correctly."""
tool = EnterpriseActionTool(
name="gmail_search_for_email",
description="Test tool",
enterprise_action_token="test_token",
action_name="GMAIL_SEARCH_FOR_EMAIL",
action_schema=self.test_schema,
)
schema_class = tool.args_schema
minimal_input = {}
validated_input = schema_class(**minimal_input)
self.assertIsNone(validated_input.filterCriteria)
self.assertIsNone(validated_input.options)
partial_input = {"options": {"limit": 10}}
validated_input = schema_class(**partial_input)
self.assertIsNone(validated_input.filterCriteria)
self.assertIsNotNone(validated_input.options)
def test_enum_validation(self):
"""Test that enum values are properly validated."""
tool = EnterpriseActionTool(
name="gmail_search_for_email",
description="Test tool",
enterprise_action_token="test_token",
action_name="GMAIL_SEARCH_FOR_EMAIL",
action_schema=self.test_schema,
)
schema_class = tool.args_schema
invalid_input = {
"filterCriteria": {
"operation": "INVALID_OPERATOR",
"rules": [],
}
}
with self.assertRaises(Exception):
schema_class(**invalid_input)
def test_required_nested_fields(self):
"""Test that required fields in nested objects are validated."""
tool = EnterpriseActionTool(
name="gmail_search_for_email",
description="Test tool",
enterprise_action_token="test_token",
action_name="GMAIL_SEARCH_FOR_EMAIL",
action_schema=self.test_schema,
)
schema_class = tool.args_schema
incomplete_input = {
"filterCriteria": {
"operation": "OR",
"rules": [
{
"field": "name",
"operator": "contains",
}
],
}
}
with self.assertRaises(Exception):
schema_class(**incomplete_input)
@patch("requests.post")
def test_tool_execution_with_complex_input(self, mock_post):
"""Test that the tool can execute with complex validated input."""
mock_response = MagicMock()
mock_response.ok = True
mock_response.json.return_value = {"success": True, "results": []}
mock_post.return_value = mock_response
tool = EnterpriseActionTool(
name="gmail_search_for_email",
description="Test tool",
enterprise_action_token="test_token",
action_name="GMAIL_SEARCH_FOR_EMAIL",
action_schema=self.test_schema,
)
tool._run(
filterCriteria={
"operation": "OR",
"rules": [
{"field": "name", "operator": "contains", "value": "test"},
{"field": "status", "operator": "equals", "value": "active"},
],
},
options={"limit": 10},
)
mock_post.assert_called_once()
call_args = mock_post.call_args
payload = call_args[1]["json"]
self.assertIn("filterCriteria", payload)
self.assertIn("options", payload)
self.assertEqual(payload["filterCriteria"]["operation"], "OR")
def test_model_naming_convention(self):
"""Test that generated model names follow proper conventions."""
tool = EnterpriseActionTool(
name="gmail_search_for_email",
description="Test tool",
enterprise_action_token="test_token",
action_name="GMAIL_SEARCH_FOR_EMAIL",
action_schema=self.test_schema,
)
schema_class = tool.args_schema
self.assertIsNotNone(schema_class)
self.assertTrue(schema_class.__name__.endswith("Schema"))
self.assertTrue(schema_class.__name__[0].isupper())
complex_input = {
"filterCriteria": {
"operation": "OR",
"rules": [
{"field": "name", "operator": "contains", "value": "test"},
{"field": "status", "operator": "equals", "value": "active"},
],
},
"options": {"limit": 10},
}
validated = schema_class(**complex_input)
self.assertIsNotNone(validated.filterCriteria)
def test_simple_schema_with_enums(self):
"""Test a simpler schema with basic enum validation."""
simple_schema = {
"type": "function",
"function": {
"name": "SIMPLE_TEST",
"description": "Simple test function",
"parameters": {
"type": "object",
"properties": {
"status": {
"type": "string",
"enum": ["active", "inactive", "pending"],
},
"priority": {"type": "integer", "enum": [1, 2, 3]},
},
"required": ["status"],
},
},
}
tool = EnterpriseActionTool(
name="simple_test",
description="Simple test tool",
enterprise_action_token="test_token",
action_name="SIMPLE_TEST",
action_schema=simple_schema,
)
schema_class = tool.args_schema
valid_input = {"status": "active", "priority": 2}
validated = schema_class(**valid_input)
self.assertEqual(validated.status, "active")
self.assertEqual(validated.priority, 2)
with self.assertRaises(Exception):
schema_class(status="invalid_status")

View File

@@ -0,0 +1,165 @@
import unittest
from unittest.mock import patch, Mock
import pytest
from crewai_tools.tools.crewai_platform_tools import CrewAIPlatformActionTool
class TestCrewAIPlatformActionTool(unittest.TestCase):
@pytest.fixture
def sample_action_schema(self):
return {
"function": {
"name": "test_action",
"description": "Test action for unit testing",
"parameters": {
"type": "object",
"properties": {
"message": {
"type": "string",
"description": "Message to send"
},
"priority": {
"type": "integer",
"description": "Priority level"
}
},
"required": ["message"]
}
}
}
@pytest.fixture
def platform_action_tool(self, sample_action_schema):
return CrewAIPlatformActionTool(
description="Test Action Tool\nTest description",
action_name="test_action",
action_schema=sample_action_schema
)
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
@patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool.requests.post")
def test_run_success(self, mock_post):
schema = {
"function": {
"name": "test_action",
"description": "Test action",
"parameters": {
"type": "object",
"properties": {
"message": {"type": "string", "description": "Message"}
},
"required": ["message"]
}
}
}
tool = CrewAIPlatformActionTool(
description="Test tool",
action_name="test_action",
action_schema=schema
)
mock_response = Mock()
mock_response.ok = True
mock_response.json.return_value = {"result": "success", "data": "test_data"}
mock_post.return_value = mock_response
result = tool._run(message="test message")
mock_post.assert_called_once()
_, kwargs = mock_post.call_args
assert "test_action/execute" in kwargs["url"]
assert kwargs["headers"]["Authorization"] == "Bearer test_token"
assert kwargs["json"]["message"] == "test message"
assert "success" in result
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
@patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool.requests.post")
def test_run_api_error(self, mock_post):
schema = {
"function": {
"name": "test_action",
"description": "Test action",
"parameters": {
"type": "object",
"properties": {
"message": {"type": "string", "description": "Message"}
},
"required": ["message"]
}
}
}
tool = CrewAIPlatformActionTool(
description="Test tool",
action_name="test_action",
action_schema=schema
)
mock_response = Mock()
mock_response.ok = False
mock_response.json.return_value = {"error": {"message": "Invalid request"}}
mock_post.return_value = mock_response
result = tool._run(message="test message")
assert "API request failed" in result
assert "Invalid request" in result
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
@patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_action_tool.requests.post")
def test_run_exception(self, mock_post):
schema = {
"function": {
"name": "test_action",
"description": "Test action",
"parameters": {
"type": "object",
"properties": {
"message": {"type": "string", "description": "Message"}
},
"required": ["message"]
}
}
}
tool = CrewAIPlatformActionTool(
description="Test tool",
action_name="test_action",
action_schema=schema
)
mock_post.side_effect = Exception("Network error")
result = tool._run(message="test message")
assert "Error executing action test_action: Network error" in result
def test_run_without_token(self):
schema = {
"function": {
"name": "test_action",
"description": "Test action",
"parameters": {
"type": "object",
"properties": {
"message": {"type": "string", "description": "Message"}
},
"required": ["message"]
}
}
}
tool = CrewAIPlatformActionTool(
description="Test tool",
action_name="test_action",
action_schema=schema
)
with patch.dict("os.environ", {}, clear=True):
result = tool._run(message="test message")
assert "Error executing action test_action:" in result
assert "No platform integration token found" in result

View File

@@ -0,0 +1,223 @@
import unittest
from unittest.mock import patch, Mock
import pytest
from crewai_tools.tools.crewai_platform_tools import CrewaiPlatformToolBuilder, CrewAIPlatformActionTool
class TestCrewaiPlatformToolBuilder(unittest.TestCase):
@pytest.fixture
def platform_tool_builder(self):
"""Create a CrewaiPlatformToolBuilder instance for testing"""
return CrewaiPlatformToolBuilder(apps=["github", "slack"])
@pytest.fixture
def mock_api_response(self):
return {
"actions": {
"github": [
{
"name": "create_issue",
"description": "Create a GitHub issue",
"parameters": {
"type": "object",
"properties": {
"title": {"type": "string", "description": "Issue title"},
"body": {"type": "string", "description": "Issue body"}
},
"required": ["title"]
}
}
],
"slack": [
{
"name": "send_message",
"description": "Send a Slack message",
"parameters": {
"type": "object",
"properties": {
"channel": {"type": "string", "description": "Channel name"},
"text": {"type": "string", "description": "Message text"}
},
"required": ["channel", "text"]
}
}
]
}
}
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
@patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get")
def test_fetch_actions_success(self, mock_get):
mock_api_response = {
"actions": {
"github": [
{
"name": "create_issue",
"description": "Create a GitHub issue",
"parameters": {
"type": "object",
"properties": {
"title": {"type": "string", "description": "Issue title"}
},
"required": ["title"]
}
}
]
}
}
builder = CrewaiPlatformToolBuilder(apps=["github", "slack/send_message"])
mock_response = Mock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = mock_api_response
mock_get.return_value = mock_response
builder._fetch_actions()
mock_get.assert_called_once()
args, kwargs = mock_get.call_args
assert "/actions" in args[0]
assert kwargs["headers"]["Authorization"] == "Bearer test_token"
assert kwargs["params"]["apps"] == "github,slack/send_message"
assert "create_issue" in builder._actions_schema
assert builder._actions_schema["create_issue"]["function"]["name"] == "create_issue"
def test_fetch_actions_no_token(self):
builder = CrewaiPlatformToolBuilder(apps=["github"])
with patch.dict("os.environ", {}, clear=True):
with self.assertRaises(ValueError) as context:
builder._fetch_actions()
assert "No platform integration token found" in str(context.exception)
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
@patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get")
def test_create_tools(self, mock_get):
mock_api_response = {
"actions": {
"github": [
{
"name": "create_issue",
"description": "Create a GitHub issue",
"parameters": {
"type": "object",
"properties": {
"title": {"type": "string", "description": "Issue title"}
},
"required": ["title"]
}
}
],
"slack": [
{
"name": "send_message",
"description": "Send a Slack message",
"parameters": {
"type": "object",
"properties": {
"channel": {"type": "string", "description": "Channel name"}
},
"required": ["channel"]
}
}
]
}
}
builder = CrewaiPlatformToolBuilder(apps=["github", "slack"])
mock_response = Mock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = mock_api_response
mock_get.return_value = mock_response
tools = builder.tools()
assert len(tools) == 2
assert all(isinstance(tool, CrewAIPlatformActionTool) for tool in tools)
tool_names = [tool.action_name for tool in tools]
assert "create_issue" in tool_names
assert "send_message" in tool_names
github_tool = next((t for t in tools if t.action_name == "create_issue"), None)
slack_tool = next((t for t in tools if t.action_name == "send_message"), None)
assert github_tool is not None
assert slack_tool is not None
assert "Create a GitHub issue" in github_tool.description
assert "Send a Slack message" in slack_tool.description
def test_tools_caching(self):
builder = CrewaiPlatformToolBuilder(apps=["github"])
cached_tools = []
def mock_create_tools():
builder._tools = cached_tools
with patch.object(builder, '_fetch_actions') as mock_fetch, \
patch.object(builder, '_create_tools', side_effect=mock_create_tools) as mock_create:
tools1 = builder.tools()
assert mock_fetch.call_count == 1
assert mock_create.call_count == 1
tools2 = builder.tools()
assert mock_fetch.call_count == 1
assert mock_create.call_count == 1
assert tools1 is tools2
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
def test_empty_apps_list(self):
builder = CrewaiPlatformToolBuilder(apps=[])
with patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get") as mock_get:
mock_response = Mock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {"actions": {}}
mock_get.return_value = mock_response
tools = builder.tools()
assert isinstance(tools, list)
assert len(tools) == 0
_, kwargs = mock_get.call_args
assert kwargs["params"]["apps"] == ""
def test_detailed_description_generation(self):
builder = CrewaiPlatformToolBuilder(apps=["test"])
complex_schema = {
"type": "object",
"properties": {
"simple_string": {"type": "string", "description": "A simple string"},
"nested_object": {
"type": "object",
"properties": {
"inner_prop": {"type": "integer", "description": "Inner property"}
},
"description": "Nested object"
},
"array_prop": {
"type": "array",
"items": {"type": "string"},
"description": "Array of strings"
}
}
}
descriptions = builder._generate_detailed_description(complex_schema)
assert isinstance(descriptions, list)
assert len(descriptions) > 0
description_text = "\n".join(descriptions)
assert "simple_string" in description_text
assert "nested_object" in description_text
assert "array_prop" in description_text

View File

@@ -0,0 +1,95 @@
import unittest
from unittest.mock import patch, Mock
from crewai_tools.tools.crewai_platform_tools import CrewaiPlatformTools
class TestCrewaiPlatformTools(unittest.TestCase):
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
@patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get")
def test_crewai_platform_tools_basic(self, mock_get):
mock_response = Mock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {"actions": {"github": []}}
mock_get.return_value = mock_response
tools = CrewaiPlatformTools(apps=["github"])
assert tools is not None
assert isinstance(tools, list)
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
@patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get")
def test_crewai_platform_tools_multiple_apps(self, mock_get):
mock_response = Mock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {
"actions": {
"github": [
{
"name": "create_issue",
"description": "Create a GitHub issue",
"parameters": {
"type": "object",
"properties": {
"title": {"type": "string", "description": "Issue title"},
"body": {"type": "string", "description": "Issue body"}
},
"required": ["title"]
}
}
],
"slack": [
{
"name": "send_message",
"description": "Send a Slack message",
"parameters": {
"type": "object",
"properties": {
"channel": {"type": "string", "description": "Channel to send to"},
"text": {"type": "string", "description": "Message text"}
},
"required": ["channel", "text"]
}
}
]
}
}
mock_get.return_value = mock_response
tools = CrewaiPlatformTools(apps=["github", "slack"])
assert tools is not None
assert isinstance(tools, list)
assert len(tools) == 2
mock_get.assert_called_once()
args, kwargs = mock_get.call_args
assert "apps=github,slack" in args[0] or kwargs.get("params", {}).get("apps") == "github,slack"
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
def test_crewai_platform_tools_empty_apps(self):
with patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get") as mock_get:
mock_response = Mock()
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {"actions": {}}
mock_get.return_value = mock_response
tools = CrewaiPlatformTools(apps=[])
assert tools is not None
assert isinstance(tools, list)
assert len(tools) == 0
@patch.dict("os.environ", {"CREWAI_PLATFORM_INTEGRATION_TOKEN": "test_token"})
@patch("crewai_tools.tools.crewai_platform_tools.crewai_platform_tool_builder.requests.get")
def test_crewai_platform_tools_api_error_handling(self, mock_get):
mock_get.side_effect = Exception("API Error")
tools = CrewaiPlatformTools(apps=["github"])
assert tools is not None
assert isinstance(tools, list)
assert len(tools) == 0
def test_crewai_platform_tools_no_token(self):
with patch.dict("os.environ", {}, clear=True):
with self.assertRaises(ValueError) as context:
CrewaiPlatformTools(apps=["github"])
assert "No platform integration token found" in str(context.exception)

View File

@@ -0,0 +1,32 @@
import os
from unittest.mock import patch
from crewai_tools import EXASearchTool
import pytest
@pytest.fixture
def exa_search_tool():
return EXASearchTool(api_key="test_api_key")
@pytest.fixture(autouse=True)
def mock_exa_api_key():
with patch.dict(os.environ, {"EXA_API_KEY": "test_key_from_env"}):
yield
def test_exa_search_tool_initialization():
with patch("crewai_tools.tools.exa_tools.exa_search_tool.Exa") as mock_exa_class:
api_key = "test_api_key"
tool = EXASearchTool(api_key=api_key)
assert tool.api_key == api_key
assert tool.content is False
assert tool.summary is False
assert tool.type == "auto"
mock_exa_class.assert_called_once_with(api_key=api_key)
def test_exa_search_tool_initialization_with_env(mock_exa_api_key):
with patch("crewai_tools.tools.exa_tools.exa_search_tool.Exa") as mock_exa_class:
EXASearchTool()
mock_exa_class.assert_called_once_with(api_key="test_key_from_env")

View File

@@ -0,0 +1,187 @@
import os
from unittest.mock import MagicMock, patch
import pytest
import requests
from crewai_tools.tools.generate_crewai_automation_tool.generate_crewai_automation_tool import (
GenerateCrewaiAutomationTool,
GenerateCrewaiAutomationToolSchema,
)
@pytest.fixture(autouse=True)
def mock_env():
with patch.dict(os.environ, {"CREWAI_PERSONAL_ACCESS_TOKEN": "test_token"}):
os.environ.pop("CREWAI_PLUS_URL", None)
yield
@pytest.fixture
def tool():
return GenerateCrewaiAutomationTool()
@pytest.fixture
def custom_url_tool():
with patch.dict(os.environ, {"CREWAI_PLUS_URL": "https://custom.crewai.com"}):
return GenerateCrewaiAutomationTool()
def test_default_initialization(tool):
assert tool.crewai_enterprise_url == "https://app.crewai.com"
assert tool.personal_access_token == "test_token"
assert tool.name == "Generate CrewAI Automation"
def test_custom_base_url_from_environment(custom_url_tool):
assert custom_url_tool.crewai_enterprise_url == "https://custom.crewai.com"
def test_personal_access_token_from_environment(tool):
assert tool.personal_access_token == "test_token"
def test_valid_prompt_only():
schema = GenerateCrewaiAutomationToolSchema(
prompt="Create a web scraping automation"
)
assert schema.prompt == "Create a web scraping automation"
assert schema.organization_id is None
def test_valid_prompt_with_organization_id():
schema = GenerateCrewaiAutomationToolSchema(
prompt="Create automation", organization_id="org-123"
)
assert schema.prompt == "Create automation"
assert schema.organization_id == "org-123"
def test_empty_prompt_validation():
schema = GenerateCrewaiAutomationToolSchema(prompt="")
assert schema.prompt == ""
assert schema.organization_id is None
@patch("requests.post")
def test_successful_generation_without_org_id(mock_post, tool):
mock_response = MagicMock()
mock_response.json.return_value = {
"url": "https://app.crewai.com/studio/project-123"
}
mock_post.return_value = mock_response
result = tool.run(prompt="Create automation")
assert (
result
== "Generated CrewAI Studio project URL: https://app.crewai.com/studio/project-123"
)
mock_post.assert_called_once_with(
"https://app.crewai.com/crewai_plus/api/v1/studio",
headers={
"Authorization": "Bearer test_token",
"Content-Type": "application/json",
"Accept": "application/json",
},
json={"prompt": "Create automation"},
)
@patch("requests.post")
def test_successful_generation_with_org_id(mock_post, tool):
mock_response = MagicMock()
mock_response.json.return_value = {
"url": "https://app.crewai.com/studio/project-456"
}
mock_post.return_value = mock_response
result = tool.run(prompt="Create automation", organization_id="org-456")
assert (
result
== "Generated CrewAI Studio project URL: https://app.crewai.com/studio/project-456"
)
mock_post.assert_called_once_with(
"https://app.crewai.com/crewai_plus/api/v1/studio",
headers={
"Authorization": "Bearer test_token",
"Content-Type": "application/json",
"Accept": "application/json",
"X-Crewai-Organization-Id": "org-456",
},
json={"prompt": "Create automation"},
)
@patch("requests.post")
def test_custom_base_url_usage(mock_post, custom_url_tool):
mock_response = MagicMock()
mock_response.json.return_value = {
"url": "https://custom.crewai.com/studio/project-789"
}
mock_post.return_value = mock_response
custom_url_tool.run(prompt="Create automation")
mock_post.assert_called_once_with(
"https://custom.crewai.com/crewai_plus/api/v1/studio",
headers={
"Authorization": "Bearer test_token",
"Content-Type": "application/json",
"Accept": "application/json",
},
json={"prompt": "Create automation"},
)
@patch("requests.post")
def test_api_error_response_handling(mock_post, tool):
mock_post.return_value.raise_for_status.side_effect = requests.HTTPError(
"400 Bad Request"
)
with pytest.raises(requests.HTTPError):
tool.run(prompt="Create automation")
@patch("requests.post")
def test_network_error_handling(mock_post, tool):
mock_post.side_effect = requests.ConnectionError("Network unreachable")
with pytest.raises(requests.ConnectionError):
tool.run(prompt="Create automation")
@patch("requests.post")
def test_api_response_missing_url(mock_post, tool):
mock_response = MagicMock()
mock_response.json.return_value = {"status": "success"}
mock_post.return_value = mock_response
result = tool.run(prompt="Create automation")
assert result == "Generated CrewAI Studio project URL: None"
def test_authorization_header_construction(tool):
headers = tool._get_headers()
assert headers["Authorization"] == "Bearer test_token"
assert headers["Content-Type"] == "application/json"
assert headers["Accept"] == "application/json"
assert "X-Crewai-Organization-Id" not in headers
def test_authorization_header_with_org_id(tool):
headers = tool._get_headers(organization_id="org-123")
assert headers["Authorization"] == "Bearer test_token"
assert headers["X-Crewai-Organization-Id"] == "org-123"
def test_missing_personal_access_token():
with patch.dict(os.environ, {}, clear=True):
tool = GenerateCrewaiAutomationTool()
assert tool.personal_access_token is None

View File

@@ -0,0 +1,47 @@
import os
import json
from urllib.parse import urlparse
from unittest.mock import patch
import pytest
from crewai_tools.tools.parallel_tools.parallel_search_tool import (
ParallelSearchTool,
)
def test_requires_env_var(monkeypatch):
monkeypatch.delenv("PARALLEL_API_KEY", raising=False)
tool = ParallelSearchTool()
result = tool.run(objective="test")
assert "PARALLEL_API_KEY" in result
@patch("crewai_tools.tools.parallel_tools.parallel_search_tool.requests.post")
def test_happy_path(mock_post, monkeypatch):
monkeypatch.setenv("PARALLEL_API_KEY", "test")
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {
"search_id": "search_123",
"results": [
{
"url": "https://www.un.org/en/about-us/history-of-the-un",
"title": "History of the United Nations",
"excerpts": [
"Four months after the San Francisco Conference ended, the United Nations officially began, on 24 October 1945..."
],
}
],
}
tool = ParallelSearchTool()
result = tool.run(objective="When was the UN established?", search_queries=["Founding year UN"])
data = json.loads(result)
assert "search_id" in data
urls = [r.get("url", "") for r in data.get("results", [])]
# Validate host against allowed set instead of substring matching
allowed_hosts = {"www.un.org", "un.org"}
assert any(urlparse(u).netloc in allowed_hosts for u in urls)

View File

@@ -0,0 +1,43 @@
import os
from tempfile import NamedTemporaryFile
from typing import cast
from unittest import mock
from pytest import fixture
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
from crewai_tools.tools.rag.rag_tool import RagTool
@fixture(autouse=True)
def mock_embedchain_db_uri():
with NamedTemporaryFile() as tmp:
uri = f"sqlite:///{tmp.name}"
with mock.patch.dict(os.environ, {"EMBEDCHAIN_DB_URI": uri}):
yield
def test_custom_llm_and_embedder():
class MyTool(RagTool):
pass
tool = MyTool(
config=dict(
llm=dict(
provider="openai",
config=dict(model="gpt-3.5-custom"),
),
embedder=dict(
provider="openai",
config=dict(model="text-embedding-3-custom"),
),
)
)
assert tool.adapter is not None
assert isinstance(tool.adapter, EmbedchainAdapter)
adapter = cast(EmbedchainAdapter, tool.adapter)
assert adapter.embedchain_app.llm.config.model == "gpt-3.5-custom"
assert (
adapter.embedchain_app.embedding_model.config.model == "text-embedding-3-custom"
)

View File

@@ -0,0 +1,129 @@
import os
import tempfile
from unittest.mock import MagicMock, patch
from bs4 import BeautifulSoup
from selenium.webdriver.chrome.options import Options
from crewai_tools.tools.selenium_scraping_tool.selenium_scraping_tool import (
SeleniumScrapingTool,
)
def mock_driver_with_html(html_content):
driver = MagicMock()
mock_element = MagicMock()
mock_element.get_attribute.return_value = html_content
bs = BeautifulSoup(html_content, "html.parser")
mock_element.text = bs.get_text()
driver.find_elements.return_value = [mock_element]
driver.find_element.return_value = mock_element
return driver
def initialize_tool_with(mock_driver):
tool = SeleniumScrapingTool(driver=mock_driver)
return tool
@patch("selenium.webdriver.Chrome")
def test_tool_initialization(mocked_chrome):
temp_dir = tempfile.mkdtemp()
mocked_chrome.return_value = MagicMock()
tool = SeleniumScrapingTool()
assert tool.website_url is None
assert tool.css_element is None
assert tool.cookie is None
assert tool.wait_time == 3
assert tool.return_html is False
try:
os.rmdir(temp_dir)
except:
pass
@patch("selenium.webdriver.Chrome")
def test_tool_initialization_with_options(mocked_chrome):
mocked_chrome.return_value = MagicMock()
options = Options()
options.add_argument("--disable-gpu")
SeleniumScrapingTool(options=options)
mocked_chrome.assert_called_once_with(options=options)
@patch("selenium.webdriver.Chrome")
def test_scrape_without_css_selector(_mocked_chrome_driver):
html_content = "<html><body><div>test content</div></body></html>"
mock_driver = mock_driver_with_html(html_content)
tool = initialize_tool_with(mock_driver)
result = tool._run(website_url="https://example.com")
assert "test content" in result
mock_driver.get.assert_called_once_with("https://example.com")
mock_driver.find_element.assert_called_with("tag name", "body")
mock_driver.close.assert_called_once()
@patch("selenium.webdriver.Chrome")
def test_scrape_with_css_selector(_mocked_chrome_driver):
html_content = "<html><body><div>test content</div><div class='test'>test content in a specific div</div></body></html>"
mock_driver = mock_driver_with_html(html_content)
tool = initialize_tool_with(mock_driver)
result = tool._run(website_url="https://example.com", css_element="div.test")
assert "test content in a specific div" in result
mock_driver.get.assert_called_once_with("https://example.com")
mock_driver.find_elements.assert_called_with("css selector", "div.test")
mock_driver.close.assert_called_once()
@patch("selenium.webdriver.Chrome")
def test_scrape_with_return_html_true(_mocked_chrome_driver):
html_content = "<html><body><div>HTML content</div></body></html>"
mock_driver = mock_driver_with_html(html_content)
tool = initialize_tool_with(mock_driver)
result = tool._run(website_url="https://example.com", return_html=True)
assert html_content in result
mock_driver.get.assert_called_once_with("https://example.com")
mock_driver.find_element.assert_called_with("tag name", "body")
mock_driver.close.assert_called_once()
@patch("selenium.webdriver.Chrome")
def test_scrape_with_return_html_false(_mocked_chrome_driver):
html_content = "<html><body><div>HTML content</div></body></html>"
mock_driver = mock_driver_with_html(html_content)
tool = initialize_tool_with(mock_driver)
result = tool._run(website_url="https://example.com", return_html=False)
assert "HTML content" in result
mock_driver.get.assert_called_once_with("https://example.com")
mock_driver.find_element.assert_called_with("tag name", "body")
mock_driver.close.assert_called_once()
@patch("selenium.webdriver.Chrome")
def test_scrape_with_driver_error(_mocked_chrome_driver):
mock_driver = MagicMock()
mock_driver.find_element.side_effect = Exception("WebDriver error occurred")
tool = initialize_tool_with(mock_driver)
result = tool._run(website_url="https://example.com")
assert result == "Error scraping website: WebDriver error occurred"
mock_driver.close.assert_called_once()
@patch("selenium.webdriver.Chrome")
def test_initialization_with_driver(_mocked_chrome_driver):
mock_driver = MagicMock()
tool = initialize_tool_with(mock_driver)
assert tool.driver == mock_driver

View File

@@ -0,0 +1,151 @@
from unittest.mock import patch
import pytest
from crewai_tools.tools.serper_dev_tool.serper_dev_tool import SerperDevTool
import os
@pytest.fixture(autouse=True)
def mock_serper_api_key():
with patch.dict(os.environ, {"SERPER_API_KEY": "test_key"}):
yield
@pytest.fixture
def serper_tool():
return SerperDevTool(n_results=2)
def test_serper_tool_initialization():
tool = SerperDevTool()
assert tool.n_results == 10
assert tool.save_file is False
assert tool.search_type == "search"
assert tool.country == ""
assert tool.location == ""
assert tool.locale == ""
def test_serper_tool_custom_initialization():
tool = SerperDevTool(
n_results=5,
save_file=True,
search_type="news",
country="US",
location="New York",
locale="en"
)
assert tool.n_results == 5
assert tool.save_file is True
assert tool.search_type == "news"
assert tool.country == "US"
assert tool.location == "New York"
assert tool.locale == "en"
@patch("requests.post")
def test_serper_tool_search(mock_post):
tool = SerperDevTool(n_results=2)
mock_response = {
"searchParameters": {
"q": "test query",
"type": "search"
},
"organic": [
{
"title": "Test Title 1",
"link": "http://test1.com",
"snippet": "Test Description 1",
"position": 1
},
{
"title": "Test Title 2",
"link": "http://test2.com",
"snippet": "Test Description 2",
"position": 2
}
],
"peopleAlsoAsk": [
{
"question": "Test Question",
"snippet": "Test Answer",
"title": "Test Source",
"link": "http://test.com"
}
]
}
mock_post.return_value.json.return_value = mock_response
mock_post.return_value.status_code = 200
result = tool.run(search_query="test query")
assert "searchParameters" in result
assert result["searchParameters"]["q"] == "test query"
assert len(result["organic"]) == 2
assert result["organic"][0]["title"] == "Test Title 1"
@patch("requests.post")
def test_serper_tool_news_search(mock_post):
tool = SerperDevTool(n_results=2, search_type="news")
mock_response = {
"searchParameters": {
"q": "test news",
"type": "news"
},
"news": [
{
"title": "News Title 1",
"link": "http://news1.com",
"snippet": "News Description 1",
"date": "2024-01-01",
"source": "News Source 1",
"imageUrl": "http://image1.com"
}
]
}
mock_post.return_value.json.return_value = mock_response
mock_post.return_value.status_code = 200
result = tool.run(search_query="test news")
assert "news" in result
assert len(result["news"]) == 1
assert result["news"][0]["title"] == "News Title 1"
@patch("requests.post")
def test_serper_tool_with_location_params(mock_post):
tool = SerperDevTool(
n_results=2,
country="US",
location="New York",
locale="en"
)
tool.run(search_query="test")
called_payload = mock_post.call_args.kwargs["json"]
assert called_payload["gl"] == "US"
assert called_payload["location"] == "New York"
assert called_payload["hl"] == "en"
def test_invalid_search_type():
tool = SerperDevTool()
with pytest.raises(ValueError) as exc_info:
tool.run(search_query="test", search_type="invalid")
assert "Invalid search type" in str(exc_info.value)
@patch("requests.post")
def test_api_error_handling(mock_post):
tool = SerperDevTool()
mock_post.side_effect = Exception("API Error")
with pytest.raises(Exception) as exc_info:
tool.run(search_query="test")
assert "API Error" in str(exc_info.value)
if __name__ == "__main__":
pytest.main([__file__])

View File

@@ -0,0 +1,336 @@
import os
from typing import Generator
import pytest
from singlestoredb import connect
from singlestoredb.server import docker
from crewai_tools import SingleStoreSearchTool
from crewai_tools.tools.singlestore_search_tool import SingleStoreSearchToolSchema
@pytest.fixture(scope="session")
def docker_server_url() -> Generator[str, None, None]:
"""Start a SingleStore Docker server for tests."""
try:
sdb = docker.start(license="")
conn = sdb.connect()
curr = conn.cursor()
curr.execute("CREATE DATABASE test_crewai")
curr.close()
conn.close()
yield sdb.connection_url
sdb.stop()
except Exception as e:
pytest.skip(f"Could not start SingleStore Docker container: {e}")
@pytest.fixture(scope="function")
def clean_db_url(docker_server_url) -> Generator[str, None, None]:
"""Provide a clean database URL and clean up tables after test."""
yield docker_server_url
try:
conn = connect(host=docker_server_url, database="test_crewai")
curr = conn.cursor()
curr.execute("SHOW TABLES")
results = curr.fetchall()
for result in results:
curr.execute(f"DROP TABLE {result[0]}")
curr.close()
conn.close()
except Exception:
# Ignore cleanup errors
pass
@pytest.fixture
def sample_table_setup(clean_db_url):
"""Set up sample tables for testing."""
conn = connect(host=clean_db_url, database="test_crewai")
curr = conn.cursor()
# Create sample tables
curr.execute(
"""
CREATE TABLE employees (
id INT PRIMARY KEY,
name VARCHAR(100),
department VARCHAR(50),
salary DECIMAL(10,2)
)
"""
)
curr.execute(
"""
CREATE TABLE departments (
id INT PRIMARY KEY,
name VARCHAR(100),
budget DECIMAL(12,2)
)
"""
)
# Insert sample data
curr.execute(
"""
INSERT INTO employees VALUES
(1, 'Alice Smith', 'Engineering', 75000.00),
(2, 'Bob Johnson', 'Marketing', 65000.00),
(3, 'Carol Davis', 'Engineering', 80000.00)
"""
)
curr.execute(
"""
INSERT INTO departments VALUES
(1, 'Engineering', 500000.00),
(2, 'Marketing', 300000.00)
"""
)
curr.close()
conn.close()
return clean_db_url
class TestSingleStoreSearchTool:
"""Test suite for SingleStoreSearchTool."""
def test_tool_creation_with_connection_params(self, sample_table_setup):
"""Test tool creation with individual connection parameters."""
# Parse URL components for individual parameters
url_parts = sample_table_setup.split("@")[1].split(":")
host = url_parts[0]
port = int(url_parts[1].split("/")[0])
user = "root"
password = sample_table_setup.split("@")[0].split(":")[2]
tool = SingleStoreSearchTool(
tables=[],
host=host,
port=port,
user=user,
password=password,
database="test_crewai",
)
assert tool.name == "Search a database's table(s) content"
assert "SingleStore" in tool.description
assert (
"employees(id int(11), name varchar(100), department varchar(50), salary decimal(10,2))"
in tool.description.lower()
)
assert (
"departments(id int(11), name varchar(100), budget decimal(12,2))"
in tool.description.lower()
)
assert tool.args_schema == SingleStoreSearchToolSchema
assert tool.connection_pool is not None
def test_tool_creation_with_connection_url(self, sample_table_setup):
"""Test tool creation with connection URL."""
tool = SingleStoreSearchTool(host=f"{sample_table_setup}/test_crewai")
assert tool.name == "Search a database's table(s) content"
assert tool.connection_pool is not None
def test_tool_creation_with_specific_tables(self, sample_table_setup):
"""Test tool creation with specific table list."""
tool = SingleStoreSearchTool(
tables=["employees"],
host=sample_table_setup,
database="test_crewai",
)
# Check that description includes specific tables
assert "employees" in tool.description
assert "departments" not in tool.description
def test_tool_creation_with_nonexistent_table(self, sample_table_setup):
"""Test tool creation fails with non-existent table."""
with pytest.raises(ValueError, match="Table nonexistent does not exist"):
SingleStoreSearchTool(
tables=["employees", "nonexistent"],
host=sample_table_setup,
database="test_crewai",
)
def test_tool_creation_with_empty_database(self, clean_db_url):
"""Test tool creation fails with empty database."""
with pytest.raises(ValueError, match="No tables found in the database"):
SingleStoreSearchTool(host=clean_db_url, database="test_crewai")
def test_description_generation(self, sample_table_setup):
"""Test that tool description is properly generated with table info."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
# Check description contains table definitions
assert "employees(" in tool.description
assert "departments(" in tool.description
assert "id int" in tool.description.lower()
assert "name varchar" in tool.description.lower()
def test_query_validation_select_allowed(self, sample_table_setup):
"""Test that SELECT queries are allowed."""
os.environ["SINGLESTOREDB_URL"] = sample_table_setup
tool = SingleStoreSearchTool(database="test_crewai")
valid, message = tool._validate_query("SELECT * FROM employees")
assert valid is True
assert message == "Valid query"
def test_query_validation_show_allowed(self, sample_table_setup):
"""Test that SHOW queries are allowed."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
valid, message = tool._validate_query("SHOW TABLES")
assert valid is True
assert message == "Valid query"
def test_query_validation_case_insensitive(self, sample_table_setup):
"""Test that query validation is case insensitive."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
valid, _ = tool._validate_query("select * from employees")
assert valid is True
valid, _ = tool._validate_query("SHOW tables")
assert valid is True
def test_query_validation_insert_denied(self, sample_table_setup):
"""Test that INSERT queries are denied."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
valid, message = tool._validate_query(
"INSERT INTO employees VALUES (4, 'Test', 'Test', 1000)"
)
assert valid is False
assert "Only SELECT and SHOW queries are supported" in message
def test_query_validation_update_denied(self, sample_table_setup):
"""Test that UPDATE queries are denied."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
valid, message = tool._validate_query("UPDATE employees SET salary = 90000")
assert valid is False
assert "Only SELECT and SHOW queries are supported" in message
def test_query_validation_delete_denied(self, sample_table_setup):
"""Test that DELETE queries are denied."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
valid, message = tool._validate_query("DELETE FROM employees WHERE id = 1")
assert valid is False
assert "Only SELECT and SHOW queries are supported" in message
def test_query_validation_non_string(self, sample_table_setup):
"""Test that non-string queries are rejected."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
valid, message = tool._validate_query(123)
assert valid is False
assert "Search query must be a string" in message
def test_run_select_query(self, sample_table_setup):
"""Test executing a SELECT query."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
result = tool._run("SELECT * FROM employees ORDER BY id")
assert "Search Results:" in result
assert "Alice Smith" in result
assert "Bob Johnson" in result
assert "Carol Davis" in result
def test_run_filtered_query(self, sample_table_setup):
"""Test executing a filtered SELECT query."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
result = tool._run(
"SELECT name FROM employees WHERE department = 'Engineering'"
)
assert "Search Results:" in result
assert "Alice Smith" in result
assert "Carol Davis" in result
assert "Bob Johnson" not in result
def test_run_show_query(self, sample_table_setup):
"""Test executing a SHOW query."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
result = tool._run("SHOW TABLES")
assert "Search Results:" in result
assert "employees" in result
assert "departments" in result
def test_run_empty_result(self, sample_table_setup):
"""Test executing a query that returns no results."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
result = tool._run("SELECT * FROM employees WHERE department = 'NonExistent'")
assert result == "No results found."
def test_run_invalid_query_syntax(self, sample_table_setup):
"""Test executing a query with invalid syntax."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
result = tool._run("SELECT * FORM employees") # Intentional typo
assert "Error executing search query:" in result
def test_run_denied_query(self, sample_table_setup):
"""Test that denied queries return appropriate error message."""
tool = SingleStoreSearchTool(host=sample_table_setup, database="test_crewai")
result = tool._run("DELETE FROM employees")
assert "Invalid search query:" in result
assert "Only SELECT and SHOW queries are supported" in result
def test_connection_pool_usage(self, sample_table_setup):
"""Test that connection pooling works correctly."""
tool = SingleStoreSearchTool(
host=sample_table_setup,
database="test_crewai",
pool_size=2,
)
# Execute multiple queries to test pool usage
results = []
for _ in range(5):
result = tool._run("SELECT COUNT(*) FROM employees")
results.append(result)
# All queries should succeed
for result in results:
assert "Search Results:" in result
assert "3" in result # Count of employees
def test_tool_schema_validation(self):
"""Test that the tool schema validation works correctly."""
# Valid input
valid_input = SingleStoreSearchToolSchema(search_query="SELECT * FROM test")
assert valid_input.search_query == "SELECT * FROM test"
# Test that description is present
schema_dict = SingleStoreSearchToolSchema.model_json_schema()
assert "search_query" in schema_dict["properties"]
assert "description" in schema_dict["properties"]["search_query"]
def test_connection_error_handling(self):
"""Test handling of connection errors."""
with pytest.raises(Exception):
# This should fail due to invalid connection parameters
SingleStoreSearchTool(
host="invalid_host",
port=9999,
user="invalid_user",
password="invalid_password",
database="invalid_db",
)

View File

@@ -0,0 +1,103 @@
import asyncio
from unittest.mock import MagicMock, patch
import pytest
from crewai_tools import SnowflakeConfig, SnowflakeSearchTool
# Unit Test Fixtures
@pytest.fixture
def mock_snowflake_connection():
mock_conn = MagicMock()
mock_cursor = MagicMock()
mock_cursor.description = [("col1",), ("col2",)]
mock_cursor.fetchall.return_value = [(1, "value1"), (2, "value2")]
mock_cursor.execute.return_value = None
mock_conn.cursor.return_value = mock_cursor
return mock_conn
@pytest.fixture
def mock_config():
return SnowflakeConfig(
account="test_account",
user="test_user",
password="test_password",
warehouse="test_warehouse",
database="test_db",
snowflake_schema="test_schema",
)
@pytest.fixture
def snowflake_tool(mock_config):
with patch("snowflake.connector.connect") as mock_connect:
tool = SnowflakeSearchTool(config=mock_config)
yield tool
# Unit Tests
@pytest.mark.asyncio
async def test_successful_query_execution(snowflake_tool, mock_snowflake_connection):
with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
mock_create_conn.return_value = mock_snowflake_connection
results = await snowflake_tool._run(
query="SELECT * FROM test_table", timeout=300
)
assert len(results) == 2
assert results[0]["col1"] == 1
assert results[0]["col2"] == "value1"
mock_snowflake_connection.cursor.assert_called_once()
@pytest.mark.asyncio
async def test_connection_pooling(snowflake_tool, mock_snowflake_connection):
with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
mock_create_conn.return_value = mock_snowflake_connection
# Execute multiple queries
await asyncio.gather(
snowflake_tool._run("SELECT 1"),
snowflake_tool._run("SELECT 2"),
snowflake_tool._run("SELECT 3"),
)
# Should reuse connections from pool
assert mock_create_conn.call_count <= snowflake_tool.pool_size
@pytest.mark.asyncio
async def test_cleanup_on_deletion(snowflake_tool, mock_snowflake_connection):
with patch.object(snowflake_tool, "_create_connection") as mock_create_conn:
mock_create_conn.return_value = mock_snowflake_connection
# Add connection to pool
await snowflake_tool._get_connection()
# Return connection to pool
async with snowflake_tool._pool_lock:
snowflake_tool._connection_pool.append(mock_snowflake_connection)
# Trigger cleanup
snowflake_tool.__del__()
mock_snowflake_connection.close.assert_called_once()
def test_config_validation():
# Test missing required fields
with pytest.raises(ValueError):
SnowflakeConfig()
# Test invalid account format
with pytest.raises(ValueError):
SnowflakeConfig(
account="invalid//account", user="test_user", password="test_pass"
)
# Test missing authentication
with pytest.raises(ValueError):
SnowflakeConfig(account="test_account", user="test_user")

View File

@@ -0,0 +1,262 @@
import sys
from unittest.mock import MagicMock, patch
import pytest
# Create mock classes that will be used by our fixture
class MockStagehandModule:
def __init__(self):
self.Stagehand = MagicMock()
self.StagehandConfig = MagicMock()
self.StagehandPage = MagicMock()
class MockStagehandSchemas:
def __init__(self):
self.ActOptions = MagicMock()
self.ExtractOptions = MagicMock()
self.ObserveOptions = MagicMock()
self.AvailableModel = MagicMock()
class MockStagehandUtils:
def __init__(self):
self.configure_logging = MagicMock()
@pytest.fixture(scope="module", autouse=True)
def mock_stagehand_modules():
"""Mock stagehand modules at the start of this test module."""
# Store original modules if they exist
original_modules = {}
for module_name in ["stagehand", "stagehand.schemas", "stagehand.utils"]:
if module_name in sys.modules:
original_modules[module_name] = sys.modules[module_name]
# Create and inject mock modules
mock_stagehand = MockStagehandModule()
mock_stagehand_schemas = MockStagehandSchemas()
mock_stagehand_utils = MockStagehandUtils()
sys.modules["stagehand"] = mock_stagehand
sys.modules["stagehand.schemas"] = mock_stagehand_schemas
sys.modules["stagehand.utils"] = mock_stagehand_utils
# Import after mocking
from crewai_tools.tools.stagehand_tool.stagehand_tool import StagehandResult, StagehandTool
# Make these available to tests in this module
sys.modules[__name__].StagehandResult = StagehandResult
sys.modules[__name__].StagehandTool = StagehandTool
yield
# Restore original modules
for module_name, module in original_modules.items():
sys.modules[module_name] = module
class MockStagehandPage(MagicMock):
def act(self, options):
mock_result = MagicMock()
mock_result.model_dump.return_value = {
"message": "Action completed successfully"
}
return mock_result
def goto(self, url):
return MagicMock()
def extract(self, options):
mock_result = MagicMock()
mock_result.model_dump.return_value = {
"data": "Extracted content",
"metadata": {"source": "test"},
}
return mock_result
def observe(self, options):
result1 = MagicMock()
result1.description = "Button element"
result1.method = "click"
result2 = MagicMock()
result2.description = "Input field"
result2.method = "type"
return [result1, result2]
class MockStagehand(MagicMock):
def init(self):
self.session_id = "test-session-id"
self.page = MockStagehandPage()
def close(self):
pass
@pytest.fixture
def mock_stagehand_instance():
with patch(
"crewai_tools.tools.stagehand_tool.stagehand_tool.Stagehand",
return_value=MockStagehand(),
) as mock:
yield mock
@pytest.fixture
def stagehand_tool():
return StagehandTool(
api_key="test_api_key",
project_id="test_project_id",
model_api_key="test_model_api_key",
_testing=True, # Enable testing mode to bypass dependency check
)
def test_stagehand_tool_initialization():
"""Test that the StagehandTool initializes with the correct default values."""
tool = StagehandTool(
api_key="test_api_key",
project_id="test_project_id",
model_api_key="test_model_api_key",
_testing=True, # Enable testing mode
)
assert tool.api_key == "test_api_key"
assert tool.project_id == "test_project_id"
assert tool.model_api_key == "test_model_api_key"
assert tool.headless is False
assert tool.dom_settle_timeout_ms == 3000
assert tool.self_heal is True
assert tool.wait_for_captcha_solves is True
@patch("crewai_tools.tools.stagehand_tool.stagehand_tool.StagehandTool._run", autospec=True)
def test_act_command(mock_run, stagehand_tool):
"""Test the 'act' command functionality."""
# Setup mock
mock_run.return_value = "Action result: Action completed successfully"
# Run the tool
result = stagehand_tool._run(
instruction="Click the submit button", command_type="act"
)
# Assertions
assert "Action result" in result
assert "Action completed successfully" in result
@patch("crewai_tools.tools.stagehand_tool.stagehand_tool.StagehandTool._run", autospec=True)
def test_navigate_command(mock_run, stagehand_tool):
"""Test the 'navigate' command functionality."""
# Setup mock
mock_run.return_value = "Successfully navigated to https://example.com"
# Run the tool
result = stagehand_tool._run(
instruction="Go to example.com",
url="https://example.com",
command_type="navigate",
)
# Assertions
assert "https://example.com" in result
@patch("crewai_tools.tools.stagehand_tool.stagehand_tool.StagehandTool._run", autospec=True)
def test_extract_command(mock_run, stagehand_tool):
"""Test the 'extract' command functionality."""
# Setup mock
mock_run.return_value = "Extracted data: {\"data\": \"Extracted content\", \"metadata\": {\"source\": \"test\"}}"
# Run the tool
result = stagehand_tool._run(
instruction="Extract all product names and prices", command_type="extract"
)
# Assertions
assert "Extracted data" in result
assert "Extracted content" in result
@patch("crewai_tools.tools.stagehand_tool.stagehand_tool.StagehandTool._run", autospec=True)
def test_observe_command(mock_run, stagehand_tool):
"""Test the 'observe' command functionality."""
# Setup mock
mock_run.return_value = "Element 1: Button element\nSuggested action: click\nElement 2: Input field\nSuggested action: type"
# Run the tool
result = stagehand_tool._run(
instruction="Find all interactive elements", command_type="observe"
)
# Assertions
assert "Element 1: Button element" in result
assert "Element 2: Input field" in result
assert "Suggested action: click" in result
assert "Suggested action: type" in result
@patch("crewai_tools.tools.stagehand_tool.stagehand_tool.StagehandTool._run", autospec=True)
def test_error_handling(mock_run, stagehand_tool):
"""Test error handling in the tool."""
# Setup mock
mock_run.return_value = "Error: Browser automation error"
# Run the tool
result = stagehand_tool._run(
instruction="Click a non-existent button", command_type="act"
)
# Assertions
assert "Error:" in result
assert "Browser automation error" in result
def test_initialization_parameters():
"""Test that the StagehandTool initializes with the correct parameters."""
# Create tool with custom parameters
tool = StagehandTool(
api_key="custom_api_key",
project_id="custom_project_id",
model_api_key="custom_model_api_key",
headless=True,
dom_settle_timeout_ms=5000,
self_heal=False,
wait_for_captcha_solves=False,
verbose=3,
_testing=True, # Enable testing mode
)
# Verify the tool was initialized with the correct parameters
assert tool.api_key == "custom_api_key"
assert tool.project_id == "custom_project_id"
assert tool.model_api_key == "custom_model_api_key"
assert tool.headless is True
assert tool.dom_settle_timeout_ms == 5000
assert tool.self_heal is False
assert tool.wait_for_captcha_solves is False
assert tool.verbose == 3
def test_close_method():
"""Test that the close method cleans up resources correctly."""
# Create the tool with testing mode
tool = StagehandTool(
api_key="test_api_key",
project_id="test_project_id",
model_api_key="test_model_api_key",
_testing=True,
)
# Setup mock stagehand instance
tool._stagehand = MagicMock()
tool._stagehand.close = MagicMock() # Non-async mock
tool._page = MagicMock()
# Call the close method
tool.close()
# Verify resources were cleaned up
assert tool._stagehand is None
assert tool._page is None

View File

@@ -0,0 +1,175 @@
from unittest.mock import patch
import pytest
from crewai_tools.tools.code_interpreter_tool.code_interpreter_tool import (
CodeInterpreterTool,
SandboxPython,
)
@pytest.fixture
def printer_mock():
with patch("crewai_tools.printer.Printer.print") as mock:
yield mock
@pytest.fixture
def docker_unavailable_mock():
with patch(
"crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.CodeInterpreterTool._check_docker_available",
return_value=False,
) as mock:
yield mock
@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
def test_run_code_in_docker(docker_mock, printer_mock):
tool = CodeInterpreterTool()
code = "print('Hello, World!')"
libraries_used = ["numpy", "pandas"]
expected_output = "Hello, World!\n"
docker_mock().containers.run().exec_run().exit_code = 0
docker_mock().containers.run().exec_run().output = expected_output.encode()
result = tool.run_code_in_docker(code, libraries_used)
assert result == expected_output
printer_mock.assert_called_with(
"Running code in Docker environment", color="bold_blue"
)
@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
def test_run_code_in_docker_with_error(docker_mock, printer_mock):
tool = CodeInterpreterTool()
code = "print(1/0)"
libraries_used = ["numpy", "pandas"]
expected_output = "Something went wrong while running the code: \nZeroDivisionError: division by zero\n"
docker_mock().containers.run().exec_run().exit_code = 1
docker_mock().containers.run().exec_run().output = (
b"ZeroDivisionError: division by zero\n"
)
result = tool.run_code_in_docker(code, libraries_used)
assert result == expected_output
printer_mock.assert_called_with(
"Running code in Docker environment", color="bold_blue"
)
@patch("crewai_tools.tools.code_interpreter_tool.code_interpreter_tool.docker_from_env")
def test_run_code_in_docker_with_script(docker_mock, printer_mock):
tool = CodeInterpreterTool()
code = """print("This is line 1")
print("This is line 2")"""
libraries_used = []
expected_output = "This is line 1\nThis is line 2\n"
docker_mock().containers.run().exec_run().exit_code = 0
docker_mock().containers.run().exec_run().output = expected_output.encode()
result = tool.run_code_in_docker(code, libraries_used)
assert result == expected_output
printer_mock.assert_called_with(
"Running code in Docker environment", color="bold_blue"
)
def test_restricted_sandbox_basic_code_execution(printer_mock, docker_unavailable_mock):
"""Test basic code execution."""
tool = CodeInterpreterTool()
code = """
result = 2 + 2
print(result)
"""
result = tool.run(code=code, libraries_used=[])
printer_mock.assert_called_with(
"Running code in restricted sandbox", color="yellow"
)
assert result == 4
def test_restricted_sandbox_running_with_blocked_modules(
printer_mock, docker_unavailable_mock
):
"""Test that restricted modules cannot be imported."""
tool = CodeInterpreterTool()
restricted_modules = SandboxPython.BLOCKED_MODULES
for module in restricted_modules:
code = f"""
import {module}
result = "Import succeeded"
"""
result = tool.run(code=code, libraries_used=[])
printer_mock.assert_called_with(
"Running code in restricted sandbox", color="yellow"
)
assert f"An error occurred: Importing '{module}' is not allowed" in result
def test_restricted_sandbox_running_with_blocked_builtins(
printer_mock, docker_unavailable_mock
):
"""Test that restricted builtins are not available."""
tool = CodeInterpreterTool()
restricted_builtins = SandboxPython.UNSAFE_BUILTINS
for builtin in restricted_builtins:
code = f"""
{builtin}("test")
result = "Builtin available"
"""
result = tool.run(code=code, libraries_used=[])
printer_mock.assert_called_with(
"Running code in restricted sandbox", color="yellow"
)
assert f"An error occurred: name '{builtin}' is not defined" in result
def test_restricted_sandbox_running_with_no_result_variable(
printer_mock, docker_unavailable_mock
):
"""Test behavior when no result variable is set."""
tool = CodeInterpreterTool()
code = """
x = 10
"""
result = tool.run(code=code, libraries_used=[])
printer_mock.assert_called_with(
"Running code in restricted sandbox", color="yellow"
)
assert result == "No result variable found."
def test_unsafe_mode_running_with_no_result_variable(
printer_mock, docker_unavailable_mock
):
"""Test behavior when no result variable is set."""
tool = CodeInterpreterTool(unsafe_mode=True)
code = """
x = 10
"""
result = tool.run(code=code, libraries_used=[])
printer_mock.assert_called_with(
"WARNING: Running code in unsafe mode", color="bold_magenta"
)
assert result == "No result variable found."
def test_unsafe_mode_running_unsafe_code(printer_mock, docker_unavailable_mock):
"""Test behavior when no result variable is set."""
tool = CodeInterpreterTool(unsafe_mode=True)
code = """
import os
os.system("ls -la")
result = eval("5/1")
"""
result = tool.run(code=code, libraries_used=[])
printer_mock.assert_called_with(
"WARNING: Running code in unsafe mode", color="bold_magenta"
)
assert 5.0 == result

View File

@@ -0,0 +1,10 @@
import pytest
from pydantic.warnings import PydanticDeprecatedSince20
@pytest.mark.filterwarnings("error", category=PydanticDeprecatedSince20)
def test_import_tools_without_pydantic_deprecation_warnings():
# This test is to ensure that the import of crewai_tools does not raise any Pydantic deprecation warnings.
import crewai_tools
assert crewai_tools

View File

@@ -0,0 +1,75 @@
import json
from unittest.mock import patch
import pytest
from crewai_tools import MongoDBVectorSearchConfig, MongoDBVectorSearchTool
# Unit Test Fixtures
@pytest.fixture
def mongodb_vector_search_tool():
tool = MongoDBVectorSearchTool(
connection_string="foo", database_name="bar", collection_name="test"
)
tool._embed_texts = lambda x: [[0.1]]
yield tool
# Unit Tests
def test_successful_query_execution(mongodb_vector_search_tool):
# Enable embedding
with patch.object(mongodb_vector_search_tool._coll, "aggregate") as mock_aggregate:
mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]
results = json.loads(mongodb_vector_search_tool._run(query="sandwiches"))
assert len(results) == 1
assert results[0]["text"] == "foo"
assert results[0]["_id"] == 1
def test_provide_config():
query_config = MongoDBVectorSearchConfig(limit=10)
tool = MongoDBVectorSearchTool(
connection_string="foo",
database_name="bar",
collection_name="test",
query_config=query_config,
vector_index_name="foo",
embedding_model="bar",
)
tool._embed_texts = lambda x: [[0.1]]
with patch.object(tool._coll, "aggregate") as mock_aggregate:
mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]
tool._run(query="sandwiches")
assert mock_aggregate.mock_calls[-1].args[0][0]["$vectorSearch"]["limit"] == 10
mock_aggregate.return_value = [dict(text="foo", score=0.1, _id=1)]
def test_cleanup_on_deletion(mongodb_vector_search_tool):
with patch.object(mongodb_vector_search_tool, "_client") as mock_client:
# Trigger cleanup
mongodb_vector_search_tool.__del__()
mock_client.close.assert_called_once()
def test_create_search_index(mongodb_vector_search_tool):
with patch(
"crewai_tools.tools.mongodb_vector_search_tool.vector_search.create_vector_search_index"
) as mock_create_search_index:
mongodb_vector_search_tool.create_vector_search_index(dimensions=10)
kwargs = mock_create_search_index.mock_calls[0].kwargs
assert kwargs["dimensions"] == 10
assert kwargs["similarity"] == "cosine"
def test_add_texts(mongodb_vector_search_tool):
with patch.object(mongodb_vector_search_tool._coll, "bulk_write") as bulk_write:
mongodb_vector_search_tool.add_texts(["foo"])
args = bulk_write.mock_calls[0].args
assert "ReplaceOne" in str(args[0][0])
assert "foo" in str(args[0][0])

View File

@@ -0,0 +1,163 @@
import json
import os
from typing import Type
from unittest.mock import MagicMock
import pytest
from crewai.tools.base_tool import BaseTool
from oxylabs import RealtimeClient
from oxylabs.sources.response import Response as OxylabsResponse
from pydantic import BaseModel
from crewai_tools import (
OxylabsAmazonProductScraperTool,
OxylabsAmazonSearchScraperTool,
OxylabsGoogleSearchScraperTool,
OxylabsUniversalScraperTool,
)
from crewai_tools.tools.oxylabs_amazon_product_scraper_tool.oxylabs_amazon_product_scraper_tool import (
OxylabsAmazonProductScraperConfig,
)
from crewai_tools.tools.oxylabs_google_search_scraper_tool.oxylabs_google_search_scraper_tool import (
OxylabsGoogleSearchScraperConfig,
)
@pytest.fixture
def oxylabs_api() -> RealtimeClient:
oxylabs_api_mock = MagicMock()
html_content = """
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8">
<title>Scraping Sandbox</title>
</head>
<body>
<div id="main">
<div id="product-list">
<div>
<p>Amazing product</p>
<p>Price $14.99</p>
</div>
<div>
<p>Good product</p>
<p>Price $9.99</p>
</div>
</div>
</div>
</body>
</html>
"""
json_content = {
"results": {
"products": [
{"title": "Amazing product", "price": 14.99, "currency": "USD"},
{"title": "Good product", "price": 9.99, "currency": "USD"},
],
},
}
html_response = OxylabsResponse({"results": [{"content": html_content}]})
json_response = OxylabsResponse({"results": [{"content": json_content}]})
oxylabs_api_mock.universal.scrape_url.side_effect = [json_response, html_response]
oxylabs_api_mock.amazon.scrape_search.side_effect = [json_response, html_response]
oxylabs_api_mock.amazon.scrape_product.side_effect = [json_response, html_response]
oxylabs_api_mock.google.scrape_search.side_effect = [json_response, html_response]
return oxylabs_api_mock
@pytest.mark.parametrize(
("tool_class",),
[
(OxylabsUniversalScraperTool,),
(OxylabsAmazonSearchScraperTool,),
(OxylabsGoogleSearchScraperTool,),
(OxylabsAmazonProductScraperTool,),
],
)
def test_tool_initialization(tool_class: Type[BaseTool]):
tool = tool_class(username="username", password="password")
assert isinstance(tool, tool_class)
@pytest.mark.parametrize(
("tool_class",),
[
(OxylabsUniversalScraperTool,),
(OxylabsAmazonSearchScraperTool,),
(OxylabsGoogleSearchScraperTool,),
(OxylabsAmazonProductScraperTool,),
],
)
def test_tool_initialization_with_env_vars(tool_class: Type[BaseTool]):
os.environ["OXYLABS_USERNAME"] = "username"
os.environ["OXYLABS_PASSWORD"] = "password"
tool = tool_class()
assert isinstance(tool, tool_class)
del os.environ["OXYLABS_USERNAME"]
del os.environ["OXYLABS_PASSWORD"]
@pytest.mark.parametrize(
("tool_class",),
[
(OxylabsUniversalScraperTool,),
(OxylabsAmazonSearchScraperTool,),
(OxylabsGoogleSearchScraperTool,),
(OxylabsAmazonProductScraperTool,),
],
)
def test_tool_initialization_failure(tool_class: Type[BaseTool]):
# making sure env vars are not set
for key in ["OXYLABS_USERNAME", "OXYLABS_PASSWORD"]:
if key in os.environ:
del os.environ[key]
with pytest.raises(ValueError):
tool_class()
@pytest.mark.parametrize(
("tool_class", "tool_config"),
[
(OxylabsUniversalScraperTool, {"geo_location": "Paris, France"}),
(
OxylabsAmazonSearchScraperTool,
{"domain": "co.uk"},
),
(
OxylabsGoogleSearchScraperTool,
OxylabsGoogleSearchScraperConfig(render="html"),
),
(
OxylabsAmazonProductScraperTool,
OxylabsAmazonProductScraperConfig(parse=True),
),
],
)
def test_tool_invocation(
tool_class: Type[BaseTool],
tool_config: BaseModel,
oxylabs_api: RealtimeClient,
):
tool = tool_class(username="username", password="password", config=tool_config)
# setting via __dict__ to bypass pydantic validation
tool.__dict__["oxylabs_api"] = oxylabs_api
# verifying parsed job returns json content
result = tool.run("Scraping Query 1")
assert isinstance(result, str)
assert isinstance(json.loads(result), dict)
# verifying raw job returns str
result = tool.run("Scraping Query 2")
assert isinstance(result, str)
assert "<!DOCTYPE html>" in result

View File

@@ -0,0 +1,309 @@
import os
import tempfile
from pathlib import Path
from unittest.mock import ANY, MagicMock
import pytest
from embedchain.models.data_type import DataType
from crewai_tools.tools import (
CodeDocsSearchTool,
CSVSearchTool,
DirectorySearchTool,
DOCXSearchTool,
GithubSearchTool,
JSONSearchTool,
MDXSearchTool,
PDFSearchTool,
TXTSearchTool,
WebsiteSearchTool,
XMLSearchTool,
YoutubeChannelSearchTool,
YoutubeVideoSearchTool,
)
from crewai_tools.tools.rag.rag_tool import Adapter
pytestmark = [pytest.mark.vcr(filter_headers=["authorization"])]
@pytest.fixture
def mock_adapter():
mock_adapter = MagicMock(spec=Adapter)
return mock_adapter
def test_directory_search_tool():
with tempfile.TemporaryDirectory() as temp_dir:
test_file = Path(temp_dir) / "test.txt"
test_file.write_text("This is a test file for directory search")
tool = DirectorySearchTool(directory=temp_dir)
result = tool._run(search_query="test file")
assert "test file" in result.lower()
def test_pdf_search_tool(mock_adapter):
mock_adapter.query.return_value = "this is a test"
tool = PDFSearchTool(pdf="test.pdf", adapter=mock_adapter)
result = tool._run(query="test content")
assert "this is a test" in result.lower()
mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
mock_adapter.query.assert_called_once_with("test content")
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = PDFSearchTool(adapter=mock_adapter)
result = tool._run(pdf="test.pdf", query="test content")
assert "this is a test" in result.lower()
mock_adapter.add.assert_called_once_with("test.pdf", data_type=DataType.PDF_FILE)
mock_adapter.query.assert_called_once_with("test content")
def test_txt_search_tool():
with tempfile.NamedTemporaryFile(suffix=".txt", delete=False) as temp_file:
temp_file.write(b"This is a test file for txt search")
temp_file_path = temp_file.name
try:
tool = TXTSearchTool()
tool.add(temp_file_path)
result = tool._run(search_query="test file")
assert "test file" in result.lower()
finally:
os.unlink(temp_file_path)
def test_docx_search_tool(mock_adapter):
mock_adapter.query.return_value = "this is a test"
tool = DOCXSearchTool(docx="test.docx", adapter=mock_adapter)
result = tool._run(search_query="test content")
assert "this is a test" in result.lower()
mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
mock_adapter.query.assert_called_once_with("test content")
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = DOCXSearchTool(adapter=mock_adapter)
result = tool._run(docx="test.docx", search_query="test content")
assert "this is a test" in result.lower()
mock_adapter.add.assert_called_once_with("test.docx", data_type=DataType.DOCX)
mock_adapter.query.assert_called_once_with("test content")
def test_json_search_tool():
with tempfile.NamedTemporaryFile(suffix=".json", delete=False) as temp_file:
temp_file.write(b'{"test": "This is a test JSON file"}')
temp_file_path = temp_file.name
try:
tool = JSONSearchTool()
result = tool._run(search_query="test JSON", json_path=temp_file_path)
assert "test json" in result.lower()
finally:
os.unlink(temp_file_path)
def test_xml_search_tool(mock_adapter):
mock_adapter.query.return_value = "this is a test"
tool = XMLSearchTool(adapter=mock_adapter)
result = tool._run(search_query="test XML", xml="test.xml")
assert "this is a test" in result.lower()
mock_adapter.add.assert_called_once_with("test.xml")
mock_adapter.query.assert_called_once_with("test XML")
def test_csv_search_tool():
with tempfile.NamedTemporaryFile(suffix=".csv", delete=False) as temp_file:
temp_file.write(b"name,description\ntest,This is a test CSV file")
temp_file_path = temp_file.name
try:
tool = CSVSearchTool()
tool.add(temp_file_path)
result = tool._run(search_query="test CSV")
assert "test csv" in result.lower()
finally:
os.unlink(temp_file_path)
def test_mdx_search_tool():
with tempfile.NamedTemporaryFile(suffix=".mdx", delete=False) as temp_file:
temp_file.write(b"# Test MDX\nThis is a test MDX file")
temp_file_path = temp_file.name
try:
tool = MDXSearchTool()
tool.add(temp_file_path)
result = tool._run(search_query="test MDX")
assert "test mdx" in result.lower()
finally:
os.unlink(temp_file_path)
def test_website_search_tool(mock_adapter):
mock_adapter.query.return_value = "this is a test"
website = "https://crewai.com"
search_query = "what is crewai?"
tool = WebsiteSearchTool(website=website, adapter=mock_adapter)
result = tool._run(search_query=search_query)
mock_adapter.query.assert_called_once_with("what is crewai?")
mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
assert "this is a test" in result.lower()
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = WebsiteSearchTool(adapter=mock_adapter)
result = tool._run(website=website, search_query=search_query)
mock_adapter.query.assert_called_once_with("what is crewai?")
mock_adapter.add.assert_called_once_with(website, data_type=DataType.WEB_PAGE)
assert "this is a test" in result.lower()
def test_youtube_video_search_tool(mock_adapter):
mock_adapter.query.return_value = "some video description"
youtube_video_url = "https://www.youtube.com/watch?v=sample-video-id"
search_query = "what is the video about?"
tool = YoutubeVideoSearchTool(
youtube_video_url=youtube_video_url,
adapter=mock_adapter,
)
result = tool._run(search_query=search_query)
assert "some video description" in result
mock_adapter.add.assert_called_once_with(
youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
)
mock_adapter.query.assert_called_once_with(search_query)
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = YoutubeVideoSearchTool(adapter=mock_adapter)
result = tool._run(youtube_video_url=youtube_video_url, search_query=search_query)
assert "some video description" in result
mock_adapter.add.assert_called_once_with(
youtube_video_url, data_type=DataType.YOUTUBE_VIDEO
)
mock_adapter.query.assert_called_once_with(search_query)
def test_youtube_channel_search_tool(mock_adapter):
mock_adapter.query.return_value = "channel description"
youtube_channel_handle = "@crewai"
search_query = "what is the channel about?"
tool = YoutubeChannelSearchTool(
youtube_channel_handle=youtube_channel_handle, adapter=mock_adapter
)
result = tool._run(search_query=search_query)
assert "channel description" in result
mock_adapter.add.assert_called_once_with(
youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
)
mock_adapter.query.assert_called_once_with(search_query)
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = YoutubeChannelSearchTool(adapter=mock_adapter)
result = tool._run(
youtube_channel_handle=youtube_channel_handle, search_query=search_query
)
assert "channel description" in result
mock_adapter.add.assert_called_once_with(
youtube_channel_handle, data_type=DataType.YOUTUBE_CHANNEL
)
mock_adapter.query.assert_called_once_with(search_query)
def test_code_docs_search_tool(mock_adapter):
mock_adapter.query.return_value = "test documentation"
docs_url = "https://crewai.com/any-docs-url"
search_query = "test documentation"
tool = CodeDocsSearchTool(docs_url=docs_url, adapter=mock_adapter)
result = tool._run(search_query=search_query)
assert "test documentation" in result
mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
mock_adapter.query.assert_called_once_with(search_query)
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = CodeDocsSearchTool(adapter=mock_adapter)
result = tool._run(docs_url=docs_url, search_query=search_query)
assert "test documentation" in result
mock_adapter.add.assert_called_once_with(docs_url, data_type=DataType.DOCS_SITE)
mock_adapter.query.assert_called_once_with(search_query)
def test_github_search_tool(mock_adapter):
mock_adapter.query.return_value = "repo description"
# ensure the provided repo and content types are used after initialization
tool = GithubSearchTool(
gh_token="test_token",
github_repo="crewai/crewai",
content_types=["code"],
adapter=mock_adapter,
)
result = tool._run(search_query="tell me about crewai repo")
assert "repo description" in result
mock_adapter.add.assert_called_once_with(
"repo:crewai/crewai type:code", data_type="github", loader=ANY
)
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
# ensure content types provided by run call is used
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
result = tool._run(
github_repo="crewai/crewai",
content_types=["code", "issue"],
search_query="tell me about crewai repo",
)
assert "repo description" in result
mock_adapter.add.assert_called_once_with(
"repo:crewai/crewai type:code,issue", data_type="github", loader=ANY
)
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
# ensure default content types are used if not provided
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
result = tool._run(
github_repo="crewai/crewai",
search_query="tell me about crewai repo",
)
assert "repo description" in result
mock_adapter.add.assert_called_once_with(
"repo:crewai/crewai type:code,repo,pr,issue", data_type="github", loader=ANY
)
mock_adapter.query.assert_called_once_with("tell me about crewai repo")
# ensure nothing is added if no repo is provided
mock_adapter.query.reset_mock()
mock_adapter.add.reset_mock()
tool = GithubSearchTool(gh_token="test_token", adapter=mock_adapter)
result = tool._run(search_query="tell me about crewai repo")
mock_adapter.add.assert_not_called()
mock_adapter.query.assert_called_once_with("tell me about crewai repo")

View File

@@ -0,0 +1,231 @@
import unittest
from unittest.mock import MagicMock
from crewai.tools import BaseTool
from crewai_tools.adapters.tool_collection import ToolCollection
class TestToolCollection(unittest.TestCase):
def setUp(self):
self.search_tool = self._create_mock_tool("SearcH", "Search Tool") # Tool name is case sensitive
self.calculator_tool = self._create_mock_tool("calculator", "Calculator Tool")
self.translator_tool = self._create_mock_tool("translator", "Translator Tool")
self.tools = ToolCollection([
self.search_tool,
self.calculator_tool,
self.translator_tool
])
def _create_mock_tool(self, name, description):
mock_tool = MagicMock(spec=BaseTool)
mock_tool.name = name
mock_tool.description = description
return mock_tool
def test_initialization(self):
self.assertEqual(len(self.tools), 3)
self.assertEqual(self.tools[0].name, "SearcH")
self.assertEqual(self.tools[1].name, "calculator")
self.assertEqual(self.tools[2].name, "translator")
def test_empty_initialization(self):
empty_collection = ToolCollection()
self.assertEqual(len(empty_collection), 0)
self.assertEqual(empty_collection._name_cache, {})
def test_initialization_with_none(self):
collection = ToolCollection(None)
self.assertEqual(len(collection), 0)
self.assertEqual(collection._name_cache, {})
def test_access_by_index(self):
self.assertEqual(self.tools[0], self.search_tool)
self.assertEqual(self.tools[1], self.calculator_tool)
self.assertEqual(self.tools[2], self.translator_tool)
def test_access_by_name(self):
self.assertEqual(self.tools["search"], self.search_tool)
self.assertEqual(self.tools["calculator"], self.calculator_tool)
self.assertEqual(self.tools["translator"], self.translator_tool)
def test_key_error_for_invalid_name(self):
with self.assertRaises(KeyError):
_ = self.tools["nonexistent"]
def test_index_error_for_invalid_index(self):
with self.assertRaises(IndexError):
_ = self.tools[10]
def test_negative_index(self):
self.assertEqual(self.tools[-1], self.translator_tool)
self.assertEqual(self.tools[-2], self.calculator_tool)
self.assertEqual(self.tools[-3], self.search_tool)
def test_append(self):
new_tool = self._create_mock_tool("new", "New Tool")
self.tools.append(new_tool)
self.assertEqual(len(self.tools), 4)
self.assertEqual(self.tools[3], new_tool)
self.assertEqual(self.tools["new"], new_tool)
self.assertIn("new", self.tools._name_cache)
def test_append_duplicate_name(self):
duplicate_tool = self._create_mock_tool("search", "Duplicate Search Tool")
self.tools.append(duplicate_tool)
self.assertEqual(len(self.tools), 4)
self.assertEqual(self.tools["search"], duplicate_tool)
def test_extend(self):
new_tools = [
self._create_mock_tool("tool4", "Tool 4"),
self._create_mock_tool("tool5", "Tool 5"),
]
self.tools.extend(new_tools)
self.assertEqual(len(self.tools), 5)
self.assertEqual(self.tools["tool4"], new_tools[0])
self.assertEqual(self.tools["tool5"], new_tools[1])
self.assertIn("tool4", self.tools._name_cache)
self.assertIn("tool5", self.tools._name_cache)
def test_insert(self):
new_tool = self._create_mock_tool("inserted", "Inserted Tool")
self.tools.insert(1, new_tool)
self.assertEqual(len(self.tools), 4)
self.assertEqual(self.tools[1], new_tool)
self.assertEqual(self.tools["inserted"], new_tool)
self.assertIn("inserted", self.tools._name_cache)
def test_remove(self):
self.tools.remove(self.calculator_tool)
self.assertEqual(len(self.tools), 2)
with self.assertRaises(KeyError):
_ = self.tools["calculator"]
self.assertNotIn("calculator", self.tools._name_cache)
def test_remove_nonexistent_tool(self):
nonexistent_tool = self._create_mock_tool("nonexistent", "Nonexistent Tool")
with self.assertRaises(ValueError):
self.tools.remove(nonexistent_tool)
def test_pop(self):
popped = self.tools.pop(1)
self.assertEqual(popped, self.calculator_tool)
self.assertEqual(len(self.tools), 2)
with self.assertRaises(KeyError):
_ = self.tools["calculator"]
self.assertNotIn("calculator", self.tools._name_cache)
def test_pop_last(self):
popped = self.tools.pop()
self.assertEqual(popped, self.translator_tool)
self.assertEqual(len(self.tools), 2)
with self.assertRaises(KeyError):
_ = self.tools["translator"]
self.assertNotIn("translator", self.tools._name_cache)
def test_clear(self):
self.tools.clear()
self.assertEqual(len(self.tools), 0)
self.assertEqual(self.tools._name_cache, {})
with self.assertRaises(KeyError):
_ = self.tools["search"]
def test_iteration(self):
tools_list = list(self.tools)
self.assertEqual(tools_list, [self.search_tool, self.calculator_tool, self.translator_tool])
def test_contains(self):
self.assertIn(self.search_tool, self.tools)
self.assertIn(self.calculator_tool, self.tools)
self.assertIn(self.translator_tool, self.tools)
nonexistent_tool = self._create_mock_tool("nonexistent", "Nonexistent Tool")
self.assertNotIn(nonexistent_tool, self.tools)
def test_slicing(self):
slice_result = self.tools[1:3]
self.assertEqual(len(slice_result), 2)
self.assertEqual(slice_result[0], self.calculator_tool)
self.assertEqual(slice_result[1], self.translator_tool)
self.assertIsInstance(slice_result, list)
self.assertNotIsInstance(slice_result, ToolCollection)
def test_getitem_with_tool_name_as_int(self):
numeric_name_tool = self._create_mock_tool("123", "Numeric Name Tool")
self.tools.append(numeric_name_tool)
self.assertEqual(self.tools["123"], numeric_name_tool)
with self.assertRaises(IndexError):
_ = self.tools[123]
def test_filter_by_names(self):
filtered = self.tools.filter_by_names(None)
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 3)
filtered = self.tools.filter_by_names(["search", "translator"])
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 2)
self.assertEqual(filtered[0], self.search_tool)
self.assertEqual(filtered[1], self.translator_tool)
self.assertEqual(filtered["search"], self.search_tool)
self.assertEqual(filtered["translator"], self.translator_tool)
filtered = self.tools.filter_by_names(["search", "nonexistent"])
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 1)
self.assertEqual(filtered[0], self.search_tool)
filtered = self.tools.filter_by_names(["nonexistent1", "nonexistent2"])
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 0)
filtered = self.tools.filter_by_names([])
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 0)
def test_filter_where(self):
filtered = self.tools.filter_where(lambda tool: tool.name.startswith("S"))
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 1)
self.assertEqual(filtered[0], self.search_tool)
self.assertEqual(filtered["search"], self.search_tool)
filtered = self.tools.filter_where(lambda tool: True)
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 3)
self.assertEqual(filtered[0], self.search_tool)
self.assertEqual(filtered[1], self.calculator_tool)
self.assertEqual(filtered[2], self.translator_tool)
filtered = self.tools.filter_where(lambda tool: False)
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 0)
filtered = self.tools.filter_where(lambda tool: len(tool.name) > 8)
self.assertIsInstance(filtered, ToolCollection)
self.assertEqual(len(filtered), 2)
self.assertEqual(filtered[0], self.calculator_tool)
self.assertEqual(filtered[1], self.translator_tool)