mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-09 20:42:35 +00:00
- Added MAX_REGEX_LENGTH to limit regex pattern length and prevent ReDoS. - Introduced allow_unrestricted_paths option to enable searching outside the current working directory. - Implemented brace expansion for glob patterns to support multiple file types. - Enhanced error handling for path traversal and regex compilation. - Updated tests to cover new features and ensure robustness.
391 lines
15 KiB
Python
391 lines
15 KiB
Python
"""Unit tests for GrepTool."""
|
|
|
|
from __future__ import annotations
|
|
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from crewai_tools import GrepTool
|
|
from crewai_tools.tools.grep_tool.grep_tool import MAX_REGEX_LENGTH
|
|
|
|
|
|
@pytest.fixture
|
|
def sample_dir(tmp_path: Path) -> Path:
|
|
"""Create a temp directory with sample files for testing."""
|
|
# src/main.py
|
|
src = tmp_path / "src"
|
|
src.mkdir()
|
|
(src / "main.py").write_text(
|
|
"def hello():\n"
|
|
" print('Hello, world!')\n"
|
|
"\n"
|
|
"def goodbye():\n"
|
|
" print('Goodbye, world!')\n"
|
|
"\n"
|
|
"class MyClass:\n"
|
|
" pass\n"
|
|
)
|
|
|
|
# src/utils.py
|
|
(src / "utils.py").write_text(
|
|
"import os\n"
|
|
"\n"
|
|
"def helper():\n"
|
|
" return os.getcwd()\n"
|
|
"\n"
|
|
"CONSTANT = 42\n"
|
|
)
|
|
|
|
# docs/readme.md
|
|
docs = tmp_path / "docs"
|
|
docs.mkdir()
|
|
(docs / "readme.md").write_text(
|
|
"# Project\n"
|
|
"\n"
|
|
"This is a sample project.\n"
|
|
"It has multiple files.\n"
|
|
)
|
|
|
|
# data/binary.bin
|
|
data = tmp_path / "data"
|
|
data.mkdir()
|
|
(data / "binary.bin").write_bytes(b"\x00\x01\x02\x03\x04binary content")
|
|
|
|
# empty.txt
|
|
(tmp_path / "empty.txt").write_text("")
|
|
|
|
# .git/config (should be skipped)
|
|
git_dir = tmp_path / ".git"
|
|
git_dir.mkdir()
|
|
(git_dir / "config").write_text("[core]\n repositoryformatversion = 0\n")
|
|
|
|
return tmp_path
|
|
|
|
|
|
class TestGrepTool:
|
|
"""Tests for GrepTool."""
|
|
|
|
def setup_method(self) -> None:
|
|
"""Set up test fixtures.
|
|
|
|
We use allow_unrestricted_paths=True so that tests using pytest's
|
|
tmp_path (which lives outside the working directory) are not rejected
|
|
by the path-restriction guard.
|
|
"""
|
|
self.tool = GrepTool(allow_unrestricted_paths=True)
|
|
|
|
def test_tool_metadata(self) -> None:
|
|
"""Test tool has correct name and description."""
|
|
assert self.tool.name == "Search file contents"
|
|
assert "search" in self.tool.description.lower() or "Search" in self.tool.description
|
|
|
|
def test_args_schema(self) -> None:
|
|
"""Test that args_schema has correct fields and defaults."""
|
|
schema = self.tool.args_schema
|
|
fields = schema.model_fields
|
|
|
|
assert "pattern" in fields
|
|
assert fields["pattern"].is_required()
|
|
|
|
assert "path" in fields
|
|
assert not fields["path"].is_required()
|
|
|
|
assert "glob_pattern" in fields
|
|
assert not fields["glob_pattern"].is_required()
|
|
|
|
assert "output_mode" in fields
|
|
assert not fields["output_mode"].is_required()
|
|
|
|
assert "case_insensitive" in fields
|
|
assert not fields["case_insensitive"].is_required()
|
|
|
|
assert "context_lines" in fields
|
|
assert not fields["context_lines"].is_required()
|
|
|
|
assert "include_line_numbers" in fields
|
|
assert not fields["include_line_numbers"].is_required()
|
|
|
|
def test_basic_pattern_match(self, sample_dir: Path) -> None:
|
|
"""Test simple string pattern found in output."""
|
|
result = self.tool._run(pattern="Hello", path=str(sample_dir))
|
|
assert "Hello" in result
|
|
|
|
def test_regex_pattern(self, sample_dir: Path) -> None:
|
|
"""Test regex pattern matches function definitions."""
|
|
result = self.tool._run(pattern=r"def\s+\w+", path=str(sample_dir))
|
|
assert "def hello" in result
|
|
assert "def goodbye" in result
|
|
assert "def helper" in result
|
|
|
|
def test_case_sensitive_default(self, sample_dir: Path) -> None:
|
|
"""Test that search is case-sensitive by default."""
|
|
result = self.tool._run(pattern="hello", path=str(sample_dir))
|
|
# "hello" (lowercase) appears in "def hello():" but not in "Hello, world!"
|
|
assert "hello" in result
|
|
# Verify it found the function definition line
|
|
assert "def hello" in result
|
|
|
|
def test_case_insensitive(self, sample_dir: Path) -> None:
|
|
"""Test case-insensitive matching."""
|
|
result = self.tool._run(
|
|
pattern="hello", path=str(sample_dir), case_insensitive=True
|
|
)
|
|
# Should match both "def hello():" and "Hello, world!"
|
|
assert "hello" in result.lower()
|
|
assert "Hello" in result
|
|
|
|
def test_output_mode_content(self, sample_dir: Path) -> None:
|
|
"""Test content output mode shows file paths, line numbers, and text."""
|
|
result = self.tool._run(
|
|
pattern="CONSTANT", path=str(sample_dir), output_mode="content"
|
|
)
|
|
assert "utils.py" in result
|
|
assert "CONSTANT" in result
|
|
# Should have line numbers by default
|
|
assert ": " in result
|
|
|
|
def test_output_mode_files_with_matches(self, sample_dir: Path) -> None:
|
|
"""Test files_with_matches output mode shows only file paths."""
|
|
result = self.tool._run(
|
|
pattern="def", path=str(sample_dir), output_mode="files_with_matches"
|
|
)
|
|
assert "main.py" in result
|
|
assert "utils.py" in result
|
|
# Should not contain line content
|
|
assert "print" not in result
|
|
|
|
def test_output_mode_count(self, sample_dir: Path) -> None:
|
|
"""Test count output mode shows filepath: N format."""
|
|
result = self.tool._run(
|
|
pattern="def", path=str(sample_dir), output_mode="count"
|
|
)
|
|
# main.py has 2 def lines, utils.py has 1
|
|
assert "main.py: 2" in result
|
|
assert "utils.py: 1" in result
|
|
|
|
def test_context_lines(self, sample_dir: Path) -> None:
|
|
"""Test surrounding context lines are included."""
|
|
result = self.tool._run(
|
|
pattern="CONSTANT", path=str(sample_dir), context_lines=2
|
|
)
|
|
# Two lines before CONSTANT = 42 is " return os.getcwd()"
|
|
assert "return os.getcwd()" in result
|
|
assert "CONSTANT" in result
|
|
|
|
def test_line_numbers_disabled(self, sample_dir: Path) -> None:
|
|
"""Test output without line number prefixes."""
|
|
result = self.tool._run(
|
|
pattern="CONSTANT",
|
|
path=str(sample_dir),
|
|
include_line_numbers=False,
|
|
)
|
|
assert "CONSTANT = 42" in result
|
|
# Verify no line number prefix (e.g., "6: ")
|
|
for line in result.strip().split("\n"):
|
|
if "CONSTANT" in line:
|
|
assert not line[0].isdigit() or ": " not in line
|
|
|
|
def test_glob_pattern_filtering(self, sample_dir: Path) -> None:
|
|
"""Test glob pattern filters to specific file types."""
|
|
result = self.tool._run(
|
|
pattern="project",
|
|
path=str(sample_dir),
|
|
glob_pattern="*.py",
|
|
case_insensitive=True,
|
|
)
|
|
# "project" appears in readme.md but not in .py files
|
|
assert "No matches found" in result
|
|
|
|
def test_search_single_file(self, sample_dir: Path) -> None:
|
|
"""Test searching a single file by path."""
|
|
file_path = str(sample_dir / "src" / "main.py")
|
|
result = self.tool._run(pattern="def", path=file_path)
|
|
assert "def hello" in result
|
|
assert "def goodbye" in result
|
|
# Should not include results from other files
|
|
assert "helper" not in result
|
|
|
|
def test_path_not_found(self) -> None:
|
|
"""Test error message when a relative path doesn't exist."""
|
|
result = self.tool._run(pattern="test", path="totally_nonexistent_subdir")
|
|
assert "Error" in result
|
|
assert "does not exist" in result
|
|
|
|
def test_invalid_regex(self, sample_dir: Path) -> None:
|
|
"""Test error message for invalid regex patterns."""
|
|
result = self.tool._run(pattern="[invalid", path=str(sample_dir))
|
|
assert "Error" in result
|
|
assert "Invalid regex" in result
|
|
|
|
def test_binary_files_skipped(self, sample_dir: Path) -> None:
|
|
"""Test binary files are not included in results."""
|
|
result = self.tool._run(pattern="binary", path=str(sample_dir))
|
|
# binary.bin has null bytes so it should be skipped
|
|
assert "binary.bin" not in result
|
|
|
|
def test_no_matches_found(self, sample_dir: Path) -> None:
|
|
"""Test message when no matches are found."""
|
|
result = self.tool._run(
|
|
pattern="zzz_nonexistent_pattern_zzz", path=str(sample_dir)
|
|
)
|
|
assert "No matches found" in result
|
|
|
|
def test_hidden_dirs_skipped(self, sample_dir: Path) -> None:
|
|
"""Test that .git/ directory contents are not searched."""
|
|
result = self.tool._run(pattern="repositoryformatversion", path=str(sample_dir))
|
|
assert "No matches found" in result
|
|
|
|
def test_empty_file(self, sample_dir: Path) -> None:
|
|
"""Test searching an empty file doesn't crash."""
|
|
result = self.tool._run(
|
|
pattern="anything", path=str(sample_dir / "empty.txt")
|
|
)
|
|
assert "No matches found" in result
|
|
|
|
def test_run_with_kwargs(self, sample_dir: Path) -> None:
|
|
"""Test _run ignores extra kwargs."""
|
|
result = self.tool._run(
|
|
pattern="Hello", path=str(sample_dir), extra_arg="ignored"
|
|
)
|
|
assert "Hello" in result
|
|
|
|
|
|
class TestPathRestriction:
|
|
"""Tests for path traversal prevention and allow_unrestricted_paths."""
|
|
|
|
def test_absolute_path_outside_cwd_blocked(self, tmp_path: Path) -> None:
|
|
"""An absolute path outside cwd is rejected by default."""
|
|
tool = GrepTool()
|
|
# tmp_path is almost certainly not under os.getcwd()
|
|
result = tool._run(pattern="anything", path=str(tmp_path))
|
|
assert "Error" in result
|
|
assert "outside the working directory" in result
|
|
|
|
def test_relative_traversal_blocked(self, sample_dir: Path) -> None:
|
|
"""A relative path with ../ that escapes cwd is rejected."""
|
|
tool = GrepTool()
|
|
result = tool._run(pattern="anything", path="../../etc")
|
|
assert "Error" in result
|
|
assert "outside the working directory" in result
|
|
|
|
def test_relative_path_within_cwd_allowed(self) -> None:
|
|
"""A relative path that stays inside cwd works fine."""
|
|
tool = GrepTool()
|
|
# "." is always within cwd
|
|
result = tool._run(pattern="zzz_will_not_match_anything_zzz", path=".")
|
|
# Should not get a traversal error — either matches or "No matches found"
|
|
assert "outside the working directory" not in result
|
|
|
|
def test_allow_unrestricted_paths_bypasses_check(self, tmp_path: Path) -> None:
|
|
"""With allow_unrestricted_paths=True, absolute paths outside cwd are allowed."""
|
|
# Write a searchable file in tmp_path
|
|
(tmp_path / "hello.txt").write_text("unrestricted search target\n")
|
|
tool = GrepTool(allow_unrestricted_paths=True)
|
|
result = tool._run(pattern="unrestricted", path=str(tmp_path))
|
|
assert "unrestricted search target" in result
|
|
|
|
def test_allow_unrestricted_defaults_false(self) -> None:
|
|
"""The flag defaults to False."""
|
|
tool = GrepTool()
|
|
assert tool.allow_unrestricted_paths is False
|
|
|
|
def test_error_message_includes_hint(self, tmp_path: Path) -> None:
|
|
"""The traversal error tells the user how to opt in."""
|
|
tool = GrepTool()
|
|
result = tool._run(pattern="x", path=str(tmp_path))
|
|
assert "GrepTool(allow_unrestricted_paths=True)" in result
|
|
|
|
|
|
class TestReDoSGuards:
|
|
"""Tests for regex denial-of-service mitigations."""
|
|
|
|
def test_pattern_length_rejected(self, sample_dir: Path) -> None:
|
|
"""Patterns exceeding MAX_REGEX_LENGTH are rejected before compilation."""
|
|
tool = GrepTool(allow_unrestricted_paths=True)
|
|
long_pattern = "a" * (MAX_REGEX_LENGTH + 1)
|
|
result = tool._run(pattern=long_pattern, path=str(sample_dir))
|
|
assert "Error" in result
|
|
assert "Pattern too long" in result
|
|
|
|
def test_pattern_at_max_length_accepted(self, sample_dir: Path) -> None:
|
|
"""A pattern exactly at MAX_REGEX_LENGTH is allowed (boundary check)."""
|
|
tool = GrepTool(allow_unrestricted_paths=True)
|
|
exact_pattern = "a" * MAX_REGEX_LENGTH
|
|
result = tool._run(pattern=exact_pattern, path=str(sample_dir))
|
|
# Should not get a length error — either matches or "No matches found"
|
|
assert "Pattern too long" not in result
|
|
|
|
def test_safe_search_returns_match(self) -> None:
|
|
"""_safe_search returns a match object for a normal pattern."""
|
|
compiled = __import__("re").compile(r"hello")
|
|
match = GrepTool._safe_search(compiled, "say hello world")
|
|
assert match is not None
|
|
assert match.group() == "hello"
|
|
|
|
def test_safe_search_returns_none_on_no_match(self) -> None:
|
|
"""_safe_search returns None when the pattern doesn't match."""
|
|
compiled = __import__("re").compile(r"zzz")
|
|
match = GrepTool._safe_search(compiled, "hello world")
|
|
assert match is None
|
|
|
|
|
|
class TestBraceExpansion:
|
|
"""Tests for glob brace expansion ({a,b} syntax)."""
|
|
|
|
def test_expand_simple_brace(self) -> None:
|
|
"""*.{py,txt} expands to ['*.py', '*.txt']."""
|
|
result = GrepTool._expand_brace_pattern("*.{py,txt}")
|
|
assert result == ["*.py", "*.txt"]
|
|
|
|
def test_expand_three_alternatives(self) -> None:
|
|
"""*.{py,txt,md} expands to three patterns."""
|
|
result = GrepTool._expand_brace_pattern("*.{py,txt,md}")
|
|
assert result == ["*.py", "*.txt", "*.md"]
|
|
|
|
def test_expand_no_braces_passthrough(self) -> None:
|
|
"""A pattern without braces is returned as a single-element list."""
|
|
result = GrepTool._expand_brace_pattern("*.py")
|
|
assert result == ["*.py"]
|
|
|
|
def test_expand_strips_whitespace(self) -> None:
|
|
"""Whitespace around alternatives inside braces is stripped."""
|
|
result = GrepTool._expand_brace_pattern("*.{ py , txt }")
|
|
assert result == ["*.py", "*.txt"]
|
|
|
|
def test_expand_prefix_and_suffix(self) -> None:
|
|
"""Prefix and suffix around the braces are preserved."""
|
|
result = GrepTool._expand_brace_pattern("src/*.{py,pyi}.bak")
|
|
assert result == ["src/*.py.bak", "src/*.pyi.bak"]
|
|
|
|
def test_brace_glob_end_to_end(self, tmp_path: Path) -> None:
|
|
"""Brace expansion works end-to-end with _collect_files."""
|
|
(tmp_path / "a.py").write_text("match_me\n")
|
|
(tmp_path / "b.txt").write_text("match_me\n")
|
|
(tmp_path / "c.md").write_text("match_me\n")
|
|
|
|
tool = GrepTool(allow_unrestricted_paths=True)
|
|
result = tool._run(
|
|
pattern="match_me",
|
|
path=str(tmp_path),
|
|
glob_pattern="*.{py,txt}",
|
|
)
|
|
assert "a.py" in result
|
|
assert "b.txt" in result
|
|
# .md should NOT be included
|
|
assert "c.md" not in result
|
|
|
|
def test_brace_glob_no_duplicates(self, tmp_path: Path) -> None:
|
|
"""Files are not reported twice when they match multiple expanded patterns."""
|
|
(tmp_path / "x.py").write_text("unique_content\n")
|
|
|
|
tool = GrepTool(allow_unrestricted_paths=True)
|
|
result = tool._run(
|
|
pattern="unique_content",
|
|
path=str(tmp_path),
|
|
glob_pattern="*.{py,py}",
|
|
output_mode="count",
|
|
)
|
|
# Should appear exactly once
|
|
assert result.count("x.py") == 1
|