fix: remove directory field from LLM schema when base_dir is set

When a developer sets base_dir, they control where files are written.
The LLM should supply only a filename, the content, and an optional overwrite flag — not a directory path.

Adds ScopedFileWriterToolInput (no directory field), which is used as the
tool's args schema when base_dir is provided at construction, following the
same pattern as FileReadTool/ScrapeWebsiteTool.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
Rip&Tear
2026-03-16 00:05:37 +08:00
parent 438dee3783
commit 61d02d7296
2 changed files with 26 additions and 37 deletions

View File

@@ -23,6 +23,13 @@ class FileWriterToolInput(BaseModel):
content: str
class ScopedFileWriterToolInput(BaseModel):
    """Input schema for the file writer tool when ``base_dir`` is set.

    This schema deliberately has no ``directory`` field: the caller supplies
    a filename, the content, and an optional overwrite flag.
    """

    # Name of the file to write.
    filename: str
    # Overwrite flag; may arrive as a bool or a string, and defaults to False.
    # NOTE(review): how a string value is interpreted is not visible here — confirm in the tool's _run.
    overwrite: str | bool = False
    # Content to write to the file.
    content: str
class FileWriterTool(BaseTool):
name: str = "File Writer Tool"
description: str = (
@@ -45,11 +52,10 @@ class FileWriterTool(BaseTool):
super().__init__(**kwargs)
self.base_dir = os.path.realpath(base_dir) if base_dir is not None else None
if base_dir is not None:
self.args_schema = ScopedFileWriterToolInput
self.description = (
f"A tool to write files within {base_dir}. "
"Accepts filename and content. "
"Optionally accepts a subdirectory (relative to the base directory) and overwrite flag. "
"Paths outside the base directory will be rejected."
f"A tool to write files into {base_dir}. "
"Accepts a filename, content, and an optional overwrite flag."
)
self._generate_description()
@@ -64,15 +70,14 @@ class FileWriterTool(BaseTool):
def _run(self, **kwargs: Any) -> str:
try:
directory = kwargs.get("directory") or ""
filename = kwargs["filename"]
# When base_dir is set, directory is relative to it.
# When not set, directory is used as-is (absolute or cwd-relative).
if self.base_dir is not None:
filepath = os.path.join(self.base_dir, directory, filename)
# Developer controls the directory; LLM only supplies filename.
filepath = os.path.join(self.base_dir, filename)
else:
filepath = os.path.join(directory or "./", filename)
directory = kwargs.get("directory") or "./"
filepath = os.path.join(directory, filename)
validated = self._validate_path(filepath)
if validated is None:

View File

@@ -144,8 +144,16 @@ def scoped_tool(temp_env):
return FileWriterTool(base_dir=temp_env["temp_dir"])
def test_base_dir_schema_has_no_directory_field(temp_env):
    """A tool constructed with base_dir uses the scoped schema, which has no directory field."""
    from crewai_tools.tools.file_writer_tool.file_writer_tool import (
        ScopedFileWriterToolInput,
    )

    # Build the tool against the temporary directory and inspect its schema.
    schema = FileWriterTool(base_dir=temp_env["temp_dir"]).args_schema

    assert schema is ScopedFileWriterToolInput
    assert "directory" not in schema.model_fields
def test_base_dir_allows_write_inside(scoped_tool, temp_env):
"""No directory arg — writes directly into base_dir."""
"""LLM supplies only filename — file lands in base_dir."""
result = scoped_tool._run(
filename=temp_env["test_file"],
content=temp_env["test_content"],
@@ -155,18 +163,6 @@ def test_base_dir_allows_write_inside(scoped_tool, temp_env):
assert read_file(get_test_path(temp_env["test_file"], temp_env["temp_dir"])) == temp_env["test_content"]
def test_base_dir_allows_relative_subdir(scoped_tool, temp_env):
"""directory arg is treated as a subdirectory of base_dir."""
result = scoped_tool._run(
filename="file.txt",
directory="subdir",
content="nested content",
overwrite=True,
)
assert "successfully written" in result
assert os.path.exists(os.path.join(temp_env["temp_dir"], "subdir", "file.txt"))
def test_base_dir_blocks_traversal_in_filename(scoped_tool, temp_env):
result = scoped_tool._run(
filename="../outside.txt",
@@ -176,20 +172,9 @@ def test_base_dir_blocks_traversal_in_filename(scoped_tool, temp_env):
assert "Access denied" in result
def test_base_dir_blocks_traversal_in_directory(scoped_tool, temp_env):
def test_base_dir_blocks_absolute_filename(scoped_tool, temp_env):
result = scoped_tool._run(
filename="pwned.txt",
directory="../../etc/cron.d",
content="should not be written",
overwrite=True,
)
assert "Access denied" in result
def test_base_dir_blocks_absolute_directory(scoped_tool, temp_env):
result = scoped_tool._run(
filename="passwd",
directory="/etc",
filename="/etc/passwd",
content="should not be written",
overwrite=True,
)
@@ -200,8 +185,7 @@ def test_base_dir_blocks_symlink_escape(scoped_tool, temp_env):
link = os.path.join(temp_env["temp_dir"], "escape")
os.symlink("/etc", link)
result = scoped_tool._run(
filename="crontab",
directory="escape",
filename="escape/crontab",
content="should not be written",
overwrite=True,
)