mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
fix: make directory relative to base_dir for better UX
When base_dir is set, the directory arg is now treated as a subdirectory of base_dir rather than an absolute path. This means the LLM only needs to specify a filename (and optionally a relative subdirectory) — it does not need to repeat the base_dir path. FileWriterTool(base_dir="./output") → filename="report.txt" writes to ./output/report.txt → filename="f.txt", directory="sub" writes to ./output/sub/f.txt Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -47,8 +47,9 @@ class FileWriterTool(BaseTool):
|
|||||||
if base_dir is not None:
|
if base_dir is not None:
|
||||||
self.description = (
|
self.description = (
|
||||||
f"A tool to write files within {base_dir}. "
|
f"A tool to write files within {base_dir}. "
|
||||||
"Accepts filename, content, and optionally a directory path and overwrite flag. "
|
"Accepts filename and content. "
|
||||||
"Paths outside the allowed directory will be rejected."
|
"Optionally accepts a subdirectory (relative to the base directory) and overwrite flag. "
|
||||||
|
"Paths outside the base directory will be rejected."
|
||||||
)
|
)
|
||||||
self._generate_description()
|
self._generate_description()
|
||||||
|
|
||||||
@@ -63,21 +64,21 @@ class FileWriterTool(BaseTool):
|
|||||||
|
|
||||||
def _run(self, **kwargs: Any) -> str:
|
def _run(self, **kwargs: Any) -> str:
|
||||||
try:
|
try:
|
||||||
directory = kwargs.get("directory") or "./"
|
directory = kwargs.get("directory") or ""
|
||||||
filename = kwargs["filename"]
|
filename = kwargs["filename"]
|
||||||
|
|
||||||
filepath = os.path.join(directory, filename)
|
# When base_dir is set, directory is relative to it.
|
||||||
|
# When not set, directory is used as-is (absolute or cwd-relative).
|
||||||
|
if self.base_dir is not None:
|
||||||
|
filepath = os.path.join(self.base_dir, directory, filename)
|
||||||
|
else:
|
||||||
|
filepath = os.path.join(directory or "./", filename)
|
||||||
|
|
||||||
validated = self._validate_path(filepath)
|
validated = self._validate_path(filepath)
|
||||||
if validated is None:
|
if validated is None:
|
||||||
return f"Error: Access denied — path resolves outside the allowed directory."
|
return "Error: Access denied — path resolves outside the allowed directory."
|
||||||
|
|
||||||
# Only create directories that are within the validated path
|
|
||||||
validated_dir = os.path.dirname(validated)
|
validated_dir = os.path.dirname(validated)
|
||||||
if self.base_dir is not None:
|
|
||||||
# Ensure the directory itself is also within base_dir
|
|
||||||
if not validated_dir.startswith(self.base_dir):
|
|
||||||
return f"Error: Access denied — directory resolves outside the allowed directory."
|
|
||||||
os.makedirs(validated_dir, exist_ok=True)
|
os.makedirs(validated_dir, exist_ok=True)
|
||||||
|
|
||||||
kwargs["overwrite"] = strtobool(kwargs["overwrite"])
|
kwargs["overwrite"] = strtobool(kwargs["overwrite"])
|
||||||
|
|||||||
@@ -145,9 +145,9 @@ def scoped_tool(temp_env):
|
|||||||
|
|
||||||
|
|
||||||
def test_base_dir_allows_write_inside(scoped_tool, temp_env):
|
def test_base_dir_allows_write_inside(scoped_tool, temp_env):
|
||||||
|
"""No directory arg — writes directly into base_dir."""
|
||||||
result = scoped_tool._run(
|
result = scoped_tool._run(
|
||||||
filename=temp_env["test_file"],
|
filename=temp_env["test_file"],
|
||||||
directory=temp_env["temp_dir"],
|
|
||||||
content=temp_env["test_content"],
|
content=temp_env["test_content"],
|
||||||
overwrite=True,
|
overwrite=True,
|
||||||
)
|
)
|
||||||
@@ -155,10 +155,21 @@ def test_base_dir_allows_write_inside(scoped_tool, temp_env):
|
|||||||
assert read_file(get_test_path(temp_env["test_file"], temp_env["temp_dir"])) == temp_env["test_content"]
|
assert read_file(get_test_path(temp_env["test_file"], temp_env["temp_dir"])) == temp_env["test_content"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_base_dir_allows_relative_subdir(scoped_tool, temp_env):
|
||||||
|
"""directory arg is treated as a subdirectory of base_dir."""
|
||||||
|
result = scoped_tool._run(
|
||||||
|
filename="file.txt",
|
||||||
|
directory="subdir",
|
||||||
|
content="nested content",
|
||||||
|
overwrite=True,
|
||||||
|
)
|
||||||
|
assert "successfully written" in result
|
||||||
|
assert os.path.exists(os.path.join(temp_env["temp_dir"], "subdir", "file.txt"))
|
||||||
|
|
||||||
|
|
||||||
def test_base_dir_blocks_traversal_in_filename(scoped_tool, temp_env):
|
def test_base_dir_blocks_traversal_in_filename(scoped_tool, temp_env):
|
||||||
result = scoped_tool._run(
|
result = scoped_tool._run(
|
||||||
filename="../outside.txt",
|
filename="../outside.txt",
|
||||||
directory=temp_env["temp_dir"],
|
|
||||||
content="should not be written",
|
content="should not be written",
|
||||||
overwrite=True,
|
overwrite=True,
|
||||||
)
|
)
|
||||||
@@ -168,14 +179,14 @@ def test_base_dir_blocks_traversal_in_filename(scoped_tool, temp_env):
|
|||||||
def test_base_dir_blocks_traversal_in_directory(scoped_tool, temp_env):
|
def test_base_dir_blocks_traversal_in_directory(scoped_tool, temp_env):
|
||||||
result = scoped_tool._run(
|
result = scoped_tool._run(
|
||||||
filename="pwned.txt",
|
filename="pwned.txt",
|
||||||
directory=os.path.join(temp_env["temp_dir"], "../../etc/cron.d"),
|
directory="../../etc/cron.d",
|
||||||
content="should not be written",
|
content="should not be written",
|
||||||
overwrite=True,
|
overwrite=True,
|
||||||
)
|
)
|
||||||
assert "Access denied" in result
|
assert "Access denied" in result
|
||||||
|
|
||||||
|
|
||||||
def test_base_dir_blocks_absolute_path_outside(scoped_tool, temp_env):
|
def test_base_dir_blocks_absolute_directory(scoped_tool, temp_env):
|
||||||
result = scoped_tool._run(
|
result = scoped_tool._run(
|
||||||
filename="passwd",
|
filename="passwd",
|
||||||
directory="/etc",
|
directory="/etc",
|
||||||
@@ -190,23 +201,13 @@ def test_base_dir_blocks_symlink_escape(scoped_tool, temp_env):
|
|||||||
os.symlink("/etc", link)
|
os.symlink("/etc", link)
|
||||||
result = scoped_tool._run(
|
result = scoped_tool._run(
|
||||||
filename="crontab",
|
filename="crontab",
|
||||||
directory=link,
|
directory="escape",
|
||||||
content="should not be written",
|
content="should not be written",
|
||||||
overwrite=True,
|
overwrite=True,
|
||||||
)
|
)
|
||||||
assert "Access denied" in result
|
assert "Access denied" in result
|
||||||
|
|
||||||
|
|
||||||
def test_base_dir_allows_nested_subdir(scoped_tool, temp_env):
|
|
||||||
result = scoped_tool._run(
|
|
||||||
filename="file.txt",
|
|
||||||
directory=os.path.join(temp_env["temp_dir"], "subdir"),
|
|
||||||
content="nested content",
|
|
||||||
overwrite=True,
|
|
||||||
)
|
|
||||||
assert "successfully written" in result
|
|
||||||
|
|
||||||
|
|
||||||
def test_base_dir_description_mentions_directory(temp_env):
|
def test_base_dir_description_mentions_directory(temp_env):
|
||||||
tool = FileWriterTool(base_dir=temp_env["temp_dir"])
|
tool = FileWriterTool(base_dir=temp_env["temp_dir"])
|
||||||
assert temp_env["temp_dir"] in tool.description
|
assert temp_env["temp_dir"] in tool.description
|
||||||
|
|||||||
Reference in New Issue
Block a user