fix: add base_dir path containment to FileWriterTool

os.path.join does not prevent traversal — joining "./" with "../../../etc/cron.d/pwned"
resolves cleanly outside any intended scope. The tool also called os.makedirs on
the unvalidated path, meaning it would create arbitrary directory structures.

Adds a base_dir parameter that uses os.path.realpath() to resolve the final path
(including symlinks) before checking containment. Any filename or directory argument
that resolves outside base_dir is rejected before any filesystem operation occurs.

When base_dir is not set the tool behaves as before — only use that in fully
sandboxed environments.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Rip&Tear
2026-03-15 23:51:01 +08:00
parent fb2323b3de
commit 6fc0914f26
2 changed files with 130 additions and 16 deletions

View File

@@ -2,7 +2,7 @@ import os
from typing import Any
from crewai.tools import BaseTool
from pydantic import BaseModel
from pydantic import BaseModel, Field
def strtobool(val) -> bool:
@@ -25,33 +25,72 @@ class FileWriterToolInput(BaseModel):
class FileWriterTool(BaseTool):
name: str = "File Writer Tool"
description: str = "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input."
description: str = (
"A tool to write content to a specified file. "
"Accepts filename, content, and optionally a directory path and overwrite flag as input."
)
args_schema: type[BaseModel] = FileWriterToolInput
base_dir: str | None = None
def __init__(self, base_dir: str | None = None, **kwargs: Any) -> None:
"""Initialize the FileWriterTool.
Args:
base_dir (Optional[str]): Restrict all writes to this directory tree.
Any filename or directory that resolves outside base_dir is rejected,
including ../traversal and symlink escapes. When not set the tool
can write anywhere the process has permission to — only use that in
fully sandboxed environments.
"""
super().__init__(**kwargs)
self.base_dir = os.path.realpath(base_dir) if base_dir is not None else None
if base_dir is not None:
self.description = (
f"A tool to write files within {base_dir}. "
"Accepts filename, content, and optionally a directory path and overwrite flag. "
"Paths outside the allowed directory will be rejected."
)
self._generate_description()
def _validate_path(self, filepath: str) -> str | None:
"""Resolve path and enforce base_dir containment. Returns None if rejected."""
if self.base_dir is None:
return filepath
real_path = os.path.realpath(filepath)
if not real_path.startswith(self.base_dir + os.sep) and real_path != self.base_dir:
return None
return real_path
def _run(self, **kwargs: Any) -> str:
try:
if kwargs.get("directory"):
os.makedirs(kwargs["directory"], exist_ok=True)
directory = kwargs.get("directory") or "./"
filename = kwargs["filename"]
# Construct the full path
filepath = os.path.join(kwargs.get("directory") or "", kwargs["filename"])
filepath = os.path.join(directory, filename)
validated = self._validate_path(filepath)
if validated is None:
return f"Error: Access denied — path resolves outside the allowed directory."
# Only create directories that are within the validated path
validated_dir = os.path.dirname(validated)
if self.base_dir is not None:
# Ensure the directory itself is also within base_dir
if not validated_dir.startswith(self.base_dir):
return f"Error: Access denied — directory resolves outside the allowed directory."
os.makedirs(validated_dir, exist_ok=True)
# Convert overwrite to boolean
kwargs["overwrite"] = strtobool(kwargs["overwrite"])
# Check if file exists and overwrite is not allowed
if os.path.exists(filepath) and not kwargs["overwrite"]:
return f"File {filepath} already exists and overwrite option was not passed."
if os.path.exists(validated) and not kwargs["overwrite"]:
return f"File {validated} already exists and overwrite option was not passed."
# Write content to the file
mode = "w" if kwargs["overwrite"] else "x"
with open(filepath, mode) as file:
with open(validated, mode) as file:
file.write(kwargs["content"])
return f"Content successfully written to {filepath}"
return f"Content successfully written to {validated}"
except FileExistsError:
return (
f"File {filepath} already exists and overwrite option was not passed."
)
return f"File already exists and overwrite option was not passed."
except KeyError as e:
return f"An error occurred while accessing key: {e!s}"
except Exception as e:

View File

@@ -135,3 +135,78 @@ def test_file_exists_error_handling(tool, temp_env, overwrite):
assert "already exists and overwrite option was not passed" in result
assert read_file(path) == "Pre-existing content"
# --- base_dir containment ---
@pytest.fixture
def scoped_tool(temp_env):
return FileWriterTool(base_dir=temp_env["temp_dir"])
def test_base_dir_allows_write_inside(scoped_tool, temp_env):
result = scoped_tool._run(
filename=temp_env["test_file"],
directory=temp_env["temp_dir"],
content=temp_env["test_content"],
overwrite=True,
)
assert "successfully written" in result
assert read_file(get_test_path(temp_env["test_file"], temp_env["temp_dir"])) == temp_env["test_content"]
def test_base_dir_blocks_traversal_in_filename(scoped_tool, temp_env):
result = scoped_tool._run(
filename="../outside.txt",
directory=temp_env["temp_dir"],
content="should not be written",
overwrite=True,
)
assert "Access denied" in result
def test_base_dir_blocks_traversal_in_directory(scoped_tool, temp_env):
result = scoped_tool._run(
filename="pwned.txt",
directory=os.path.join(temp_env["temp_dir"], "../../etc/cron.d"),
content="should not be written",
overwrite=True,
)
assert "Access denied" in result
def test_base_dir_blocks_absolute_path_outside(scoped_tool, temp_env):
result = scoped_tool._run(
filename="passwd",
directory="/etc",
content="should not be written",
overwrite=True,
)
assert "Access denied" in result
def test_base_dir_blocks_symlink_escape(scoped_tool, temp_env):
link = os.path.join(temp_env["temp_dir"], "escape")
os.symlink("/etc", link)
result = scoped_tool._run(
filename="crontab",
directory=link,
content="should not be written",
overwrite=True,
)
assert "Access denied" in result
def test_base_dir_allows_nested_subdir(scoped_tool, temp_env):
result = scoped_tool._run(
filename="file.txt",
directory=os.path.join(temp_env["temp_dir"], "subdir"),
content="nested content",
overwrite=True,
)
assert "successfully written" in result
def test_base_dir_description_mentions_directory(temp_env):
tool = FileWriterTool(base_dir=temp_env["temp_dir"])
assert temp_env["temp_dir"] in tool.description