feat: enhance GrepTool with regex length limit, path restrictions, and brace expansion support

- Added MAX_REGEX_LENGTH to limit regex pattern length and prevent ReDoS.
- Introduced allow_unrestricted_paths option to enable searching outside the current working directory.
- Implemented brace expansion for glob patterns to support multiple file types.
- Enhanced error handling for path traversal and regex compilation.
- Updated tests to cover new features and ensure robustness.
This commit is contained in:
lorenzejay
2026-02-11 20:44:46 -08:00
parent 25835ca795
commit ec2b6a0287
2 changed files with 254 additions and 15 deletions

View File

@@ -2,20 +2,26 @@
from __future__ import annotations
import os
import re
from dataclasses import dataclass, field
from itertools import chain
import os
from pathlib import Path
import re
import signal
import sys
from typing import Literal
from crewai.tools import BaseTool
from pydantic import BaseModel, Field
MAX_OUTPUT_CHARS = 50_000
MAX_FILES = 10_000
MAX_MATCHES_PER_FILE = 200
MAX_LINE_LENGTH = 500
BINARY_CHECK_SIZE = 8192
MAX_REGEX_LENGTH = 1_000
REGEX_MATCH_TIMEOUT_SECONDS = 5
SKIP_DIRS = frozenset(
{
@@ -61,7 +67,7 @@ class GrepToolSchema(BaseModel):
)
glob_pattern: str | None = Field(
default=None,
description="Glob pattern to filter files (e.g. '*.py', '*.{ts,tsx}')",
description="Glob pattern to filter files (e.g. '*.py'). Supports brace expansion (e.g. '*.{ts,tsx}').",
)
output_mode: Literal["content", "files_with_matches", "count"] = Field(
default="content",
@@ -89,13 +95,16 @@ class GrepTool(BaseTool):
Example:
>>> tool = GrepTool()
>>> result = tool.run(pattern="def.*main", path="/path/to/project")
>>> result = tool.run(pattern="def.*main", path="src")
>>> result = tool.run(
... pattern="TODO",
... path="/path/to/project",
... glob_pattern="*.py",
... context_lines=2,
... )
To search any path on the filesystem (opt-in):
>>> tool = GrepTool(allow_unrestricted_paths=True)
>>> result = tool.run(pattern="error", path="/var/log/app")
"""
name: str = "Search file contents"
@@ -105,6 +114,13 @@ class GrepTool(BaseTool):
"Returns matching content with line numbers, file paths only, or match counts."
)
args_schema: type[BaseModel] = GrepToolSchema
allow_unrestricted_paths: bool = Field(
default=False,
description=(
"When False (default), searches are restricted to the current working "
"directory. Set to True to allow searching any path on the filesystem."
),
)
def _run(
self,
@@ -131,12 +147,31 @@ class GrepTool(BaseTool):
Returns:
Formatted search results as a string.
"""
# Resolve search path
search_path = Path(path) if path else Path(os.getcwd())
# Resolve search path — constrained to cwd unless unrestricted
cwd = Path(os.getcwd()).resolve()
if path:
candidate = Path(path)
if candidate.is_absolute():
search_path = candidate.resolve()
else:
search_path = (cwd / candidate).resolve()
# Prevent traversal outside the working directory (unless opted in)
if not self.allow_unrestricted_paths:
try:
search_path.relative_to(cwd)
except ValueError:
return (
f"Error: Path '{path}' is outside the working directory. "
"Initialize with GrepTool(allow_unrestricted_paths=True) to allow this."
)
else:
search_path = cwd
if not search_path.exists():
return f"Error: Path '{search_path}' does not exist."
# Compile regex
# Compile regex with length guard to mitigate ReDoS
if len(pattern) > MAX_REGEX_LENGTH:
return f"Error: Pattern too long ({len(pattern)} chars). Maximum is {MAX_REGEX_LENGTH}."
flags = re.IGNORECASE if case_insensitive else 0
try:
compiled = re.compile(pattern, flags)
@@ -173,6 +208,28 @@ class GrepTool(BaseTool):
return output
@staticmethod
def _expand_brace_pattern(pattern: str) -> list[str]:
"""Expand a simple brace pattern into individual globs.
Handles a single level of brace expansion, e.g.
``*.{py,txt}`` -> ``['*.py', '*.txt']``.
Nested braces are *not* supported and the pattern is returned as-is.
Args:
pattern: Glob pattern that may contain ``{a,b,...}`` syntax.
Returns:
List of expanded patterns (or the original if no braces found).
"""
match = re.search(r"\{([^{}]+)\}", pattern)
if not match:
return [pattern]
prefix = pattern[: match.start()]
suffix = pattern[match.end() :]
alternatives = match.group(1).split(",")
return [f"{prefix}{alt.strip()}{suffix}" for alt in alternatives]
def _collect_files(self, search_path: Path, glob_pattern: str | None) -> list[Path]:
"""Collect files to search.
@@ -186,11 +243,15 @@ class GrepTool(BaseTool):
if search_path.is_file():
return [search_path]
pattern = glob_pattern or "*"
patterns = self._expand_brace_pattern(glob_pattern) if glob_pattern else ["*"]
seen: set[Path] = set()
files: list[Path] = []
for p in search_path.rglob(pattern):
for p in chain.from_iterable(search_path.rglob(pat) for pat in patterns):
if not p.is_file():
continue
if p in seen:
continue
seen.add(p)
# Skip hidden/build directories
if any(part in SKIP_DIRS for part in p.relative_to(search_path).parts):
continue
@@ -200,6 +261,37 @@ class GrepTool(BaseTool):
return sorted(files)
@staticmethod
def _safe_search(compiled_pattern: re.Pattern[str], line: str) -> re.Match[str] | None:
"""Run a regex search with a per-line timeout to mitigate ReDoS.
On platforms that support SIGALRM (Unix), a timeout is enforced.
On Windows, the search runs without a timeout but is still bounded
by MAX_LINE_LENGTH truncation applied earlier in the pipeline.
Args:
compiled_pattern: Compiled regex pattern.
line: The text line to search.
Returns:
Match object if found, None otherwise (including on timeout).
"""
if sys.platform == "win32":
return compiled_pattern.search(line)
def _timeout_handler(signum: int, frame: object) -> None:
raise TimeoutError
old_handler = signal.signal(signal.SIGALRM, _timeout_handler)
signal.alarm(REGEX_MATCH_TIMEOUT_SECONDS)
try:
return compiled_pattern.search(line)
except TimeoutError:
return None
finally:
signal.alarm(0)
signal.signal(signal.SIGALRM, old_handler)
def _is_binary_file(self, file_path: Path) -> bool:
"""Check if a file is binary by looking for null bytes.
@@ -244,7 +336,7 @@ class GrepTool(BaseTool):
# Find matching line numbers
match_line_nums: list[int] = []
for i, line in enumerate(lines):
if compiled_pattern.search(line):
if self._safe_search(compiled_pattern, line):
match_line_nums.append(i)
if len(match_line_nums) >= MAX_MATCHES_PER_FILE:
break

View File

@@ -1,10 +1,13 @@
"""Unit tests for GrepTool."""
from __future__ import annotations
from pathlib import Path
import pytest
from crewai_tools import GrepTool
from crewai_tools.tools.grep_tool.grep_tool import MAX_REGEX_LENGTH
@pytest.fixture
@@ -64,8 +67,13 @@ class TestGrepTool:
"""Tests for GrepTool."""
def setup_method(self) -> None:
"""Set up test fixtures."""
self.tool = GrepTool()
"""Set up test fixtures.
We use allow_unrestricted_paths=True so that tests using pytest's
tmp_path (which lives outside the working directory) are not rejected
by the path-restriction guard.
"""
self.tool = GrepTool(allow_unrestricted_paths=True)
def test_tool_metadata(self) -> None:
"""Test tool has correct name and description."""
@@ -199,8 +207,8 @@ class TestGrepTool:
assert "helper" not in result
def test_path_not_found(self) -> None:
"""Test error message when path doesn't exist."""
result = self.tool._run(pattern="test", path="/nonexistent/path")
"""Test error message when a relative path doesn't exist."""
result = self.tool._run(pattern="test", path="totally_nonexistent_subdir")
assert "Error" in result
assert "does not exist" in result
@@ -241,3 +249,142 @@ class TestGrepTool:
pattern="Hello", path=str(sample_dir), extra_arg="ignored"
)
assert "Hello" in result
class TestPathRestriction:
"""Tests for path traversal prevention and allow_unrestricted_paths."""
def test_absolute_path_outside_cwd_blocked(self, tmp_path: Path) -> None:
"""An absolute path outside cwd is rejected by default."""
tool = GrepTool()
# tmp_path is almost certainly not under os.getcwd()
result = tool._run(pattern="anything", path=str(tmp_path))
assert "Error" in result
assert "outside the working directory" in result
def test_relative_traversal_blocked(self, sample_dir: Path) -> None:
"""A relative path with ../ that escapes cwd is rejected."""
tool = GrepTool()
result = tool._run(pattern="anything", path="../../etc")
assert "Error" in result
assert "outside the working directory" in result
def test_relative_path_within_cwd_allowed(self) -> None:
"""A relative path that stays inside cwd works fine."""
tool = GrepTool()
# "." is always within cwd
result = tool._run(pattern="zzz_will_not_match_anything_zzz", path=".")
# Should not get a traversal error — either matches or "No matches found"
assert "outside the working directory" not in result
def test_allow_unrestricted_paths_bypasses_check(self, tmp_path: Path) -> None:
"""With allow_unrestricted_paths=True, absolute paths outside cwd are allowed."""
# Write a searchable file in tmp_path
(tmp_path / "hello.txt").write_text("unrestricted search target\n")
tool = GrepTool(allow_unrestricted_paths=True)
result = tool._run(pattern="unrestricted", path=str(tmp_path))
assert "unrestricted search target" in result
def test_allow_unrestricted_defaults_false(self) -> None:
"""The flag defaults to False."""
tool = GrepTool()
assert tool.allow_unrestricted_paths is False
def test_error_message_includes_hint(self, tmp_path: Path) -> None:
"""The traversal error tells the user how to opt in."""
tool = GrepTool()
result = tool._run(pattern="x", path=str(tmp_path))
assert "GrepTool(allow_unrestricted_paths=True)" in result
class TestReDoSGuards:
"""Tests for regex denial-of-service mitigations."""
def test_pattern_length_rejected(self, sample_dir: Path) -> None:
"""Patterns exceeding MAX_REGEX_LENGTH are rejected before compilation."""
tool = GrepTool(allow_unrestricted_paths=True)
long_pattern = "a" * (MAX_REGEX_LENGTH + 1)
result = tool._run(pattern=long_pattern, path=str(sample_dir))
assert "Error" in result
assert "Pattern too long" in result
def test_pattern_at_max_length_accepted(self, sample_dir: Path) -> None:
"""A pattern exactly at MAX_REGEX_LENGTH is allowed (boundary check)."""
tool = GrepTool(allow_unrestricted_paths=True)
exact_pattern = "a" * MAX_REGEX_LENGTH
result = tool._run(pattern=exact_pattern, path=str(sample_dir))
# Should not get a length error — either matches or "No matches found"
assert "Pattern too long" not in result
def test_safe_search_returns_match(self) -> None:
"""_safe_search returns a match object for a normal pattern."""
compiled = __import__("re").compile(r"hello")
match = GrepTool._safe_search(compiled, "say hello world")
assert match is not None
assert match.group() == "hello"
def test_safe_search_returns_none_on_no_match(self) -> None:
"""_safe_search returns None when the pattern doesn't match."""
compiled = __import__("re").compile(r"zzz")
match = GrepTool._safe_search(compiled, "hello world")
assert match is None
class TestBraceExpansion:
"""Tests for glob brace expansion ({a,b} syntax)."""
def test_expand_simple_brace(self) -> None:
"""*.{py,txt} expands to ['*.py', '*.txt']."""
result = GrepTool._expand_brace_pattern("*.{py,txt}")
assert result == ["*.py", "*.txt"]
def test_expand_three_alternatives(self) -> None:
"""*.{py,txt,md} expands to three patterns."""
result = GrepTool._expand_brace_pattern("*.{py,txt,md}")
assert result == ["*.py", "*.txt", "*.md"]
def test_expand_no_braces_passthrough(self) -> None:
"""A pattern without braces is returned as a single-element list."""
result = GrepTool._expand_brace_pattern("*.py")
assert result == ["*.py"]
def test_expand_strips_whitespace(self) -> None:
"""Whitespace around alternatives inside braces is stripped."""
result = GrepTool._expand_brace_pattern("*.{ py , txt }")
assert result == ["*.py", "*.txt"]
def test_expand_prefix_and_suffix(self) -> None:
"""Prefix and suffix around the braces are preserved."""
result = GrepTool._expand_brace_pattern("src/*.{py,pyi}.bak")
assert result == ["src/*.py.bak", "src/*.pyi.bak"]
def test_brace_glob_end_to_end(self, tmp_path: Path) -> None:
"""Brace expansion works end-to-end with _collect_files."""
(tmp_path / "a.py").write_text("match_me\n")
(tmp_path / "b.txt").write_text("match_me\n")
(tmp_path / "c.md").write_text("match_me\n")
tool = GrepTool(allow_unrestricted_paths=True)
result = tool._run(
pattern="match_me",
path=str(tmp_path),
glob_pattern="*.{py,txt}",
)
assert "a.py" in result
assert "b.txt" in result
# .md should NOT be included
assert "c.md" not in result
def test_brace_glob_no_duplicates(self, tmp_path: Path) -> None:
"""Files are not reported twice when they match multiple expanded patterns."""
(tmp_path / "x.py").write_text("unique_content\n")
tool = GrepTool(allow_unrestricted_paths=True)
result = tool._run(
pattern="unique_content",
path=str(tmp_path),
glob_pattern="*.{py,py}",
output_mode="count",
)
# Should appear exactly once
assert result.count("x.py") == 1