fix: do not use deprecated distutils in FileWriterTool (#280)

This commit is contained in:
Lucas Gomide
2025-04-23 11:29:07 -03:00
committed by GitHub
parent 7973c163f3
commit 78d0ec501d
2 changed files with 155 additions and 4 deletions

View File

@@ -1,21 +1,34 @@
import os import os
from distutils.util import strtobool
from typing import Any, Optional, Type from typing import Any, Optional, Type
from crewai.tools import BaseTool from crewai.tools import BaseTool
from pydantic import BaseModel from pydantic import BaseModel
def strtobool(val) -> bool:
if isinstance(val, bool):
return val
val = val.lower()
if val in ("y", "yes", "t", "true", "on", "1"):
return True
elif val in ("n", "no", "f", "false", "off", "0"):
return False
else:
raise ValueError(f"invalid value to cast to bool: {val!r}")
class FileWriterToolInput(BaseModel): class FileWriterToolInput(BaseModel):
filename: str filename: str
directory: Optional[str] = "./" directory: Optional[str] = "./"
overwrite: str = "False" overwrite: str | bool = False
content: str content: str
class FileWriterTool(BaseTool): class FileWriterTool(BaseTool):
name: str = "File Writer Tool" name: str = "File Writer Tool"
description: str = "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input." description: str = (
"A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input."
)
args_schema: Type[BaseModel] = FileWriterToolInput args_schema: Type[BaseModel] = FileWriterToolInput
def _run(self, **kwargs: Any) -> str: def _run(self, **kwargs: Any) -> str:
@@ -28,7 +41,7 @@ class FileWriterTool(BaseTool):
filepath = os.path.join(kwargs.get("directory") or "", kwargs["filename"]) filepath = os.path.join(kwargs.get("directory") or "", kwargs["filename"])
# Convert overwrite to boolean # Convert overwrite to boolean
kwargs["overwrite"] = bool(strtobool(kwargs["overwrite"])) kwargs["overwrite"] = strtobool(kwargs["overwrite"])
# Check if file exists and overwrite is not allowed # Check if file exists and overwrite is not allowed
if os.path.exists(filepath) and not kwargs["overwrite"]: if os.path.exists(filepath) and not kwargs["overwrite"]:

View File

@@ -0,0 +1,138 @@
import os
import shutil
import tempfile
import pytest
from crewai_tools.tools.file_writer_tool.file_writer_tool import FileWriterTool
@pytest.fixture
def tool():
return FileWriterTool()
@pytest.fixture
def temp_env():
temp_dir = tempfile.mkdtemp()
test_file = "test.txt"
test_content = "Hello, World!"
yield {
"temp_dir": temp_dir,
"test_file": test_file,
"test_content": test_content,
}
shutil.rmtree(temp_dir, ignore_errors=True)
def get_test_path(filename, directory):
return os.path.join(directory, filename)
def read_file(path):
with open(path, "r") as f:
return f.read()
def test_basic_file_write(tool, temp_env):
result = tool._run(
filename=temp_env["test_file"],
directory=temp_env["temp_dir"],
content=temp_env["test_content"],
overwrite=True,
)
path = get_test_path(temp_env["test_file"], temp_env["temp_dir"])
assert os.path.exists(path)
assert read_file(path) == temp_env["test_content"]
assert "successfully written" in result
def test_directory_creation(tool, temp_env):
new_dir = os.path.join(temp_env["temp_dir"], "nested_dir")
result = tool._run(
filename=temp_env["test_file"],
directory=new_dir,
content=temp_env["test_content"],
overwrite=True,
)
path = get_test_path(temp_env["test_file"], new_dir)
assert os.path.exists(new_dir)
assert os.path.exists(path)
assert "successfully written" in result
@pytest.mark.parametrize(
"overwrite",
["y", "yes", "t", "true", "on", "1", True],
)
def test_overwrite_true(tool, temp_env, overwrite):
path = get_test_path(temp_env["test_file"], temp_env["temp_dir"])
with open(path, "w") as f:
f.write("Original content")
result = tool._run(
filename=temp_env["test_file"],
directory=temp_env["temp_dir"],
content="New content",
overwrite=overwrite,
)
assert read_file(path) == "New content"
assert "successfully written" in result
def test_invalid_overwrite_value(tool, temp_env):
result = tool._run(
filename=temp_env["test_file"],
directory=temp_env["temp_dir"],
content=temp_env["test_content"],
overwrite="invalid",
)
assert "invalid value" in result
def test_missing_required_fields(tool, temp_env):
result = tool._run(
directory=temp_env["temp_dir"],
content=temp_env["test_content"],
overwrite=True,
)
assert "An error occurred while accessing key: 'filename'" in result
def test_empty_content(tool, temp_env):
result = tool._run(
filename=temp_env["test_file"],
directory=temp_env["temp_dir"],
content="",
overwrite=True,
)
path = get_test_path(temp_env["test_file"], temp_env["temp_dir"])
assert os.path.exists(path)
assert read_file(path) == ""
assert "successfully written" in result
@pytest.mark.parametrize(
"overwrite",
["n", "no", "f", "false", "off", "0", False],
)
def test_file_exists_error_handling(tool, temp_env, overwrite):
path = get_test_path(temp_env["test_file"], temp_env["temp_dir"])
with open(path, "w") as f:
f.write("Pre-existing content")
result = tool._run(
filename=temp_env["test_file"],
directory=temp_env["temp_dir"],
content="Should not be written",
overwrite=overwrite,
)
assert "already exists and overwrite option was not passed" in result
assert read_file(path) == "Pre-existing content"