diff --git a/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py b/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py index f975d3301..8b9ca5225 100644 --- a/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py +++ b/src/crewai_tools/tools/file_writer_tool/file_writer_tool.py @@ -1,21 +1,34 @@ import os -from distutils.util import strtobool from typing import Any, Optional, Type from crewai.tools import BaseTool from pydantic import BaseModel +def strtobool(val) -> bool: + if isinstance(val, bool): + return val + val = val.lower() + if val in ("y", "yes", "t", "true", "on", "1"): + return True + elif val in ("n", "no", "f", "false", "off", "0"): + return False + else: + raise ValueError(f"invalid value to cast to bool: {val!r}") + + class FileWriterToolInput(BaseModel): filename: str directory: Optional[str] = "./" - overwrite: str = "False" + overwrite: str | bool = False content: str class FileWriterTool(BaseTool): name: str = "File Writer Tool" - description: str = "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input." + description: str = ( + "A tool to write content to a specified file. Accepts filename, content, and optionally a directory path and overwrite flag as input." + ) args_schema: Type[BaseModel] = FileWriterToolInput def _run(self, **kwargs: Any) -> str: @@ -28,7 +41,7 @@ class FileWriterTool(BaseTool): filepath = os.path.join(kwargs.get("directory") or "", kwargs["filename"]) # Convert overwrite to boolean - kwargs["overwrite"] = bool(strtobool(kwargs["overwrite"])) + kwargs["overwrite"] = strtobool(kwargs["overwrite"]) # Check if file exists and overwrite is not allowed if os.path.exists(filepath) and not kwargs["overwrite"]: diff --git a/src/crewai_tools/tools/file_writer_tool/tests/test_file_writer_tool.py b/src/crewai_tools/tools/file_writer_tool/tests/test_file_writer_tool.py new file mode 100644 index 000000000..d75ed30f2 --- /dev/null +++ b/src/crewai_tools/tools/file_writer_tool/tests/test_file_writer_tool.py @@ -0,0 +1,138 @@ +import os +import shutil +import tempfile + +import pytest + +from crewai_tools.tools.file_writer_tool.file_writer_tool import FileWriterTool + + +@pytest.fixture +def tool(): + return FileWriterTool() + + +@pytest.fixture +def temp_env(): + temp_dir = tempfile.mkdtemp() + test_file = "test.txt" + test_content = "Hello, World!" + + yield { + "temp_dir": temp_dir, + "test_file": test_file, + "test_content": test_content, + } + + shutil.rmtree(temp_dir, ignore_errors=True) + + +def get_test_path(filename, directory): + return os.path.join(directory, filename) + + +def read_file(path): + with open(path, "r") as f: + return f.read() + + +def test_basic_file_write(tool, temp_env): + result = tool._run( + filename=temp_env["test_file"], + directory=temp_env["temp_dir"], + content=temp_env["test_content"], + overwrite=True, + ) + + path = get_test_path(temp_env["test_file"], temp_env["temp_dir"]) + assert os.path.exists(path) + assert read_file(path) == temp_env["test_content"] + assert "successfully written" in result + + +def test_directory_creation(tool, temp_env): + new_dir = os.path.join(temp_env["temp_dir"], "nested_dir") + result = tool._run( + filename=temp_env["test_file"], + directory=new_dir, + content=temp_env["test_content"], + overwrite=True, + ) + + path = get_test_path(temp_env["test_file"], new_dir) + assert os.path.exists(new_dir) + assert os.path.exists(path) + assert "successfully written" in result + + +@pytest.mark.parametrize( + "overwrite", + ["y", "yes", "t", "true", "on", "1", True], +) +def test_overwrite_true(tool, temp_env, overwrite): + path = get_test_path(temp_env["test_file"], temp_env["temp_dir"]) + with open(path, "w") as f: + f.write("Original content") + + result = tool._run( + filename=temp_env["test_file"], + directory=temp_env["temp_dir"], + content="New content", + overwrite=overwrite, + ) + + assert read_file(path) == "New content" + assert "successfully written" in result + + +def test_invalid_overwrite_value(tool, temp_env): + result = tool._run( + filename=temp_env["test_file"], + directory=temp_env["temp_dir"], + content=temp_env["test_content"], + overwrite="invalid", + ) + assert "invalid value" in result + + +def test_missing_required_fields(tool, temp_env): + result = tool._run( + directory=temp_env["temp_dir"], + content=temp_env["test_content"], + overwrite=True, + ) + assert "An error occurred while accessing key: 'filename'" in result + + +def test_empty_content(tool, temp_env): + result = tool._run( + filename=temp_env["test_file"], + directory=temp_env["temp_dir"], + content="", + overwrite=True, + ) + + path = get_test_path(temp_env["test_file"], temp_env["temp_dir"]) + assert os.path.exists(path) + assert read_file(path) == "" + assert "successfully written" in result + + +@pytest.mark.parametrize( + "overwrite", + ["n", "no", "f", "false", "off", "0", False], +) +def test_file_exists_error_handling(tool, temp_env, overwrite): + path = get_test_path(temp_env["test_file"], temp_env["temp_dir"]) + with open(path, "w") as f: + f.write("Pre-existing content") + + result = tool._run( + filename=temp_env["test_file"], + directory=temp_env["temp_dir"], + content="Should not be written", + overwrite=overwrite, + ) + + assert "already exists and overwrite option was not passed" in result + assert read_file(path) == "Pre-existing content"