From d94ea52db796f141fea30f052b2c49e17753321d Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Tue, 25 Feb 2025 16:42:47 -0500 Subject: [PATCH] improve task guardrails to address issue #2177 --- src/crewai/task.py | 28 +++++++++--- tests/task_test.py | 106 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 128 insertions(+), 6 deletions(-) diff --git a/src/crewai/task.py b/src/crewai/task.py index b9e341e33..748e401e4 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -19,6 +19,8 @@ from typing import ( Tuple, Type, Union, + get_args, + get_origin, ) from pydantic import ( @@ -172,15 +174,29 @@ class Task(BaseModel): """ if v is not None: sig = inspect.signature(v) - if len(sig.parameters) != 1: + positional_args = [ + param + for param in sig.parameters.values() + if param.default is inspect.Parameter.empty + ] + if len(positional_args) != 1: raise ValueError("Guardrail function must accept exactly one parameter") # Check return annotation if present, but don't require it return_annotation = sig.return_annotation if return_annotation != inspect.Signature.empty: + + return_annotation_args = get_args(return_annotation) if not ( - return_annotation == Tuple[bool, Any] - or str(return_annotation) == "Tuple[bool, Any]" + get_origin(return_annotation) is tuple + and len(return_annotation_args) == 2 + and return_annotation_args[0] is bool + and ( + return_annotation_args[1] is Any + or return_annotation_args[1] is str + or return_annotation_args[1] is TaskOutput + or return_annotation_args[1] == Union[str, TaskOutput] + ) ): raise ValueError( "If return type is annotated, it must be Tuple[bool, Any]" @@ -435,9 +451,9 @@ class Task(BaseModel): content = ( json_output if json_output - else pydantic_output.model_dump_json() - if pydantic_output - else result + else ( + pydantic_output.model_dump_json() if pydantic_output else result + ) ) self._save_file(content) crewai_event_bus.emit(self, TaskCompletedEvent(output=task_output)) diff --git a/tests/task_test.py b/tests/task_test.py index 3cd11cfc7..897f1e2dd 100644 --- a/tests/task_test.py +++ b/tests/task_test.py @@ -1283,3 +1283,109 @@ def test_interpolate_valid_types(): assert parsed["optional"] is None assert parsed["nested"]["flag"] is True assert parsed["nested"]["empty"] is None + + +def test_guardrail_with_new_style_annotations(): + """Test that guardrails with new-style type annotations work correctly.""" + + # Define a guardrail with new-style annotation + def guardrail(result: TaskOutput) -> tuple[bool, str]: + return (True, result.raw.upper()) + + agent = MagicMock() + agent.role = "test_agent" + agent.execute_task.return_value = "test result" + agent.crew = None + + task = Task(description="Test task", expected_output="Output", guardrail=guardrail) + + result = task.execute_sync(agent=agent) + assert isinstance(result, TaskOutput) + assert result.raw == "TEST RESULT" + + +def test_guardrail_with_specific_return_type(): + """Test that guardrails with specific return types work correctly.""" + + # Define a guardrail with specific return type + def guardrail(result: TaskOutput) -> tuple[bool, TaskOutput]: + if "error" in result.raw.lower(): + return (False, "Contains error") + return (True, result) + + agent = MagicMock() + agent.role = "test_agent" + agent.execute_task.return_value = "success result" + agent.crew = None + + task = Task(description="Test task", expected_output="Output", guardrail=guardrail) + + result = task.execute_sync(agent=agent) + assert isinstance(result, TaskOutput) + assert result.raw == "success result" + + +def test_guardrail_with_positional_and_default_args(): + """Test that guardrails with positional and default arguments work correctly.""" + + # Define a guardrail with a positional argument and a default argument + def guardrail(result: TaskOutput, optional_arg=None) -> tuple[bool, str]: + return (True, result.raw.upper()) + + agent = MagicMock() + agent.role = "test_agent" + agent.execute_task.return_value = "test result" + agent.crew = None + + # This should now work with the updated validator + task = Task(description="Test task", expected_output="Output", guardrail=guardrail) + + result = task.execute_sync(agent=agent) + assert isinstance(result, TaskOutput) + assert result.raw == "TEST RESULT" + + +def test_guardrail_with_multiple_positional_args(): + """Test that guardrails with multiple positional arguments are rejected.""" + + # Define a guardrail with multiple positional arguments + def guardrail(result: TaskOutput, another_required_arg) -> tuple[bool, str]: + return (True, result.raw.upper()) + + agent = MagicMock() + agent.role = "test_agent" + agent.execute_task.return_value = "test result" + agent.crew = None + + # This should raise a ValueError because guardrail must accept exactly one positional parameter + with pytest.raises(ValueError) as excinfo: + Task(description="Test task", expected_output="Output", guardrail=guardrail) + + assert "Guardrail function must accept exactly one parameter" in str(excinfo.value) + + +def test_guardrail_with_positional_and_default_args(): + """Validate that the guardrail function has the correct signature and behavior. + + While type hints provide static checking, this validator ensures runtime safety by: + 1. Verifying the function accepts exactly one required parameter (the TaskOutput) + (additional parameters with default values are allowed) + 2. Checking return type annotations match Tuple[bool, Any] or tuple[bool, Any] if present + 3. Providing clear, immediate error messages for debugging + """ + + # Define a guardrail with a positional argument and a default argument + def guardrail(result: TaskOutput, optional_arg=None) -> tuple[bool, str]: + return (True, result.raw.upper()) + + agent = MagicMock() + agent.role = "test_agent" + agent.execute_task.return_value = "test result" + agent.crew = None + + # This should now work with the updated validator + task = Task(description="Test task", expected_output="Output", guardrail=guardrail) + + result = task.execute_sync(agent=agent) + assert isinstance(result, TaskOutput) + assert result.raw == "TEST RESULT"