From 0ed683241d10d7e188d7ccdef2c8482f5af035df Mon Sep 17 00:00:00 2001 From: Lucas Gomide Date: Tue, 22 Apr 2025 12:32:12 -0300 Subject: [PATCH] feat: allow to set unsafe_mode from Guardrail task --- src/crewai/tasks/guardrail_task.py | 10 +++++++-- tests/test_task_guardrails.py | 34 +++++++++++++++++++++++------- 2 files changed, 34 insertions(+), 10 deletions(-) diff --git a/src/crewai/tasks/guardrail_task.py b/src/crewai/tasks/guardrail_task.py index f801247ce..d46a1d5e1 100644 --- a/src/crewai/tasks/guardrail_task.py +++ b/src/crewai/tasks/guardrail_task.py @@ -19,7 +19,7 @@ class GuardrailTask: task (Task, optional): The task whose output needs validation. llm (LLM, optional): The language model to use for code generation. additional_instructions (str, optional): Additional instructions for the guardrail task. - + unsafe_mode (bool, optional): Whether to run the code in unsafe mode. Raises: ValueError: If no valid LLM is provided. """ @@ -30,6 +30,7 @@ class GuardrailTask: task: Task | None = None, llm: LLM | None = None, additional_instructions: str = "", + unsafe_mode: bool | None = None, ): self.description = description @@ -44,6 +45,7 @@ class GuardrailTask: self.llm: LLM | None = llm or fallback_llm self.additional_instructions = additional_instructions + self.unsafe_mode = unsafe_mode @property def system_instructions(self) -> str: @@ -138,7 +140,11 @@ class GuardrailTask: code = self.generate_code(task_output) - unsafe_mode = not self.check_docker_available() + unsafe_mode = ( + self.unsafe_mode + if self.unsafe_mode is not None + else not self.check_docker_available() + ) result = CodeInterpreterTool(code=code, unsafe_mode=unsafe_mode).run() diff --git a/tests/test_task_guardrails.py b/tests/test_task_guardrails.py index 7fae055b7..c17503199 100644 --- a/tests/test_task_guardrails.py +++ b/tests/test_task_guardrails.py @@ -334,14 +334,12 @@ def test_guardrail_task_when_docker_is_not_available(mock_llm, task_output): ) as mock_init, patch( "crewai_tools.CodeInterpreterTool.run", return_value=(True, "Valid output") - ) as mock_run, + ), patch( "subprocess.run", side_effect=FileNotFoundError, ), ): - mock_init.return_value = None - mock_run.return_value = (True, "Valid output") guardrail(task_output) mock_init.assert_called_once_with(code=ANY, unsafe_mode=True) @@ -361,8 +359,6 @@ def test_guardrail_task_when_docker_is_available(mock_llm, task_output): return_value=True, ), ): - mock_init.return_value = None - mock_run.return_value = (True, "Valid output") guardrail(task_output) mock_init.assert_called_once_with(code=ANY, unsafe_mode=False) @@ -380,10 +376,32 @@ def test_guardrail_task_when_tool_output_is_not_valid(mock_llm, task_output): patch( "subprocess.run", return_value=True, - ), + ) as docker_check, ): - mock_init.return_value = None - mock_run.return_value = (True, "Valid output") guardrail(task_output) mock_init.assert_called_once_with(code=ANY, unsafe_mode=False) + docker_check.assert_called_once() + + +@pytest.mark.parametrize("unsafe_mode", [True, False]) +def test_guardrail_task_force_code_tool_unsafe_mode(mock_llm, task_output, unsafe_mode): + guardrail = GuardrailTask( + description="Test validation", llm=mock_llm, unsafe_mode=unsafe_mode + ) + with ( + patch( + "crewai_tools.CodeInterpreterTool.__init__", return_value=None + ) as mock_init, + patch( + "crewai_tools.CodeInterpreterTool.run", return_value=(True, "Valid output") + ), + patch( + "subprocess.run", + side_effect=FileNotFoundError, + ) as docker_check, + ): + guardrail(task_output) + + docker_check.assert_not_called() + mock_init.assert_called_once_with(code=ANY, unsafe_mode=unsafe_mode)