mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-04 08:42:38 +00:00
feat: enhance GrepTool with regex length limit, path restrictions, and brace expansion support
- Added MAX_REGEX_LENGTH to limit regex pattern length and mitigate ReDoS.
- Introduced the allow_unrestricted_paths option to enable searching outside the current working directory.
- Implemented brace expansion for glob patterns to support multiple file types.
- Enhanced error handling for path traversal and regex compilation.
- Updated tests to cover the new features and ensure robustness.
This commit is contained in:
@@ -1,10 +1,13 @@
|
||||
"""Unit tests for GrepTool."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_tools import GrepTool
|
||||
from crewai_tools.tools.grep_tool.grep_tool import MAX_REGEX_LENGTH
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -64,8 +67,13 @@ class TestGrepTool:
|
||||
"""Tests for GrepTool."""
|
||||
|
||||
def setup_method(self) -> None:
    """Set up test fixtures.

    We use allow_unrestricted_paths=True so that tests using pytest's
    tmp_path (which lives outside the working directory) are not rejected
    by the path-restriction guard.
    """
    # Build the tool once; a default-constructed GrepTool() here would be
    # immediately overwritten and is therefore not created.
    self.tool = GrepTool(allow_unrestricted_paths=True)
|
||||
|
||||
def test_tool_metadata(self) -> None:
|
||||
"""Test tool has correct name and description."""
|
||||
@@ -199,8 +207,8 @@ class TestGrepTool:
|
||||
assert "helper" not in result
|
||||
|
||||
def test_path_not_found(self) -> None:
    """Test error message when a relative path doesn't exist."""
    # Only the relative-path lookup is asserted on; an earlier lookup whose
    # result would be discarded before the assertions is not performed.
    result = self.tool._run(pattern="test", path="totally_nonexistent_subdir")
    assert "Error" in result
    assert "does not exist" in result
|
||||
|
||||
@@ -241,3 +249,142 @@ class TestGrepTool:
|
||||
pattern="Hello", path=str(sample_dir), extra_arg="ignored"
|
||||
)
|
||||
assert "Hello" in result
|
||||
|
||||
|
||||
class TestPathRestriction:
    """Checks for the working-directory guard and its allow_unrestricted_paths opt-out."""

    def test_absolute_path_outside_cwd_blocked(self, tmp_path: Path) -> None:
        """A default tool reports an error for an absolute path outside cwd."""
        # pytest's tmp_path is almost certainly not located under os.getcwd()
        output = GrepTool()._run(pattern="anything", path=str(tmp_path))
        assert "Error" in output
        assert "outside the working directory" in output

    def test_relative_traversal_blocked(self, sample_dir: Path) -> None:
        """A ../ path that climbs out of cwd is reported as a traversal error."""
        output = GrepTool()._run(pattern="anything", path="../../etc")
        assert "Error" in output
        assert "outside the working directory" in output

    def test_relative_path_within_cwd_allowed(self) -> None:
        """Searching a relative path inside cwd does not trip the guard."""
        # "." always resolves to cwd itself
        output = GrepTool()._run(
            pattern="zzz_will_not_match_anything_zzz",
            path=".",
        )
        # Matches or "No matches found" are both fine; only the traversal
        # error must be absent.
        assert "outside the working directory" not in output

    def test_allow_unrestricted_paths_bypasses_check(self, tmp_path: Path) -> None:
        """Opting in with allow_unrestricted_paths=True permits paths outside cwd."""
        # Create something searchable inside tmp_path
        target = tmp_path / "hello.txt"
        target.write_text("unrestricted search target\n")
        unrestricted = GrepTool(allow_unrestricted_paths=True)
        output = unrestricted._run(pattern="unrestricted", path=str(tmp_path))
        assert "unrestricted search target" in output

    def test_allow_unrestricted_defaults_false(self) -> None:
        """A default-constructed tool has allow_unrestricted_paths set to False."""
        assert GrepTool().allow_unrestricted_paths is False

    def test_error_message_includes_hint(self, tmp_path: Path) -> None:
        """The traversal error text names the constructor flag that opts in."""
        output = GrepTool()._run(pattern="x", path=str(tmp_path))
        assert "GrepTool(allow_unrestricted_paths=True)" in output
|
||||
|
||||
|
||||
class TestReDoSGuards:
    """Tests for regex denial-of-service mitigations."""

    def test_pattern_length_rejected(self, sample_dir: Path) -> None:
        """Patterns exceeding MAX_REGEX_LENGTH are rejected before compilation."""
        tool = GrepTool(allow_unrestricted_paths=True)
        # One character over the limit is the smallest rejected input.
        long_pattern = "a" * (MAX_REGEX_LENGTH + 1)
        result = tool._run(pattern=long_pattern, path=str(sample_dir))
        assert "Error" in result
        assert "Pattern too long" in result

    def test_pattern_at_max_length_accepted(self, sample_dir: Path) -> None:
        """A pattern exactly at MAX_REGEX_LENGTH is allowed (boundary check)."""
        tool = GrepTool(allow_unrestricted_paths=True)
        exact_pattern = "a" * MAX_REGEX_LENGTH
        result = tool._run(pattern=exact_pattern, path=str(sample_dir))
        # Should not get a length error — either matches or "No matches found"
        assert "Pattern too long" not in result

    def test_safe_search_returns_match(self) -> None:
        """_safe_search returns a match object for a normal pattern."""
        # Regular import statement instead of calling __import__("re") inline.
        import re

        compiled = re.compile(r"hello")
        match = GrepTool._safe_search(compiled, "say hello world")
        assert match is not None
        assert match.group() == "hello"

    def test_safe_search_returns_none_on_no_match(self) -> None:
        """_safe_search returns None when the pattern doesn't match."""
        # Regular import statement instead of calling __import__("re") inline.
        import re

        compiled = re.compile(r"zzz")
        match = GrepTool._safe_search(compiled, "hello world")
        assert match is None
|
||||
|
||||
|
||||
class TestBraceExpansion:
    """Checks for {a,b} brace expansion in glob patterns."""

    def test_expand_simple_brace(self) -> None:
        """Two alternatives: *.{py,txt} becomes ['*.py', '*.txt']."""
        expanded = GrepTool._expand_brace_pattern("*.{py,txt}")
        assert expanded == ["*.py", "*.txt"]

    def test_expand_three_alternatives(self) -> None:
        """Three alternatives: *.{py,txt,md} becomes three patterns."""
        expanded = GrepTool._expand_brace_pattern("*.{py,txt,md}")
        assert expanded == ["*.py", "*.txt", "*.md"]

    def test_expand_no_braces_passthrough(self) -> None:
        """Without braces the pattern comes back as a one-element list."""
        expanded = GrepTool._expand_brace_pattern("*.py")
        assert expanded == ["*.py"]

    def test_expand_strips_whitespace(self) -> None:
        """Spaces surrounding the alternatives inside braces are removed."""
        expanded = GrepTool._expand_brace_pattern("*.{ py , txt }")
        assert expanded == ["*.py", "*.txt"]

    def test_expand_prefix_and_suffix(self) -> None:
        """Text before and after the braces is kept on every expansion."""
        expanded = GrepTool._expand_brace_pattern("src/*.{py,pyi}.bak")
        assert expanded == ["src/*.py.bak", "src/*.pyi.bak"]

    def test_brace_glob_end_to_end(self, tmp_path: Path) -> None:
        """A brace glob passed to _run filters the searched files."""
        for filename in ("a.py", "b.txt", "c.md"):
            (tmp_path / filename).write_text("match_me\n")

        grep = GrepTool(allow_unrestricted_paths=True)
        output = grep._run(
            pattern="match_me",
            path=str(tmp_path),
            glob_pattern="*.{py,txt}",
        )
        assert "a.py" in output
        assert "b.txt" in output
        # the .md file is outside the glob and must be absent
        assert "c.md" not in output

    def test_brace_glob_no_duplicates(self, tmp_path: Path) -> None:
        """A file matching several expanded patterns is listed only once."""
        (tmp_path / "x.py").write_text("unique_content\n")

        grep = GrepTool(allow_unrestricted_paths=True)
        output = grep._run(
            pattern="unique_content",
            path=str(tmp_path),
            glob_pattern="*.{py,py}",
            output_mode="count",
        )
        # exactly one mention of the file name is expected
        occurrences = output.count("x.py")
        assert occurrences == 1
|
||||
|
||||
Reference in New Issue
Block a user