diff --git a/src/crewai/task.py b/src/crewai/task.py index b9e341e33..aa32049be 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -172,18 +172,29 @@ class Task(BaseModel): """ if v is not None: sig = inspect.signature(v) - if len(sig.parameters) != 1: - raise ValueError("Guardrail function must accept exactly one parameter") + # Get required parameters (excluding those with defaults) + required_params = [ + param for param in sig.parameters.values() + if param.default == inspect.Parameter.empty + ] + if len(required_params) != 1: + raise ValueError("Guardrail function must accept exactly one required parameter") # Check return annotation if present, but don't require it return_annotation = sig.return_annotation if return_annotation != inspect.Signature.empty: - if not ( - return_annotation == Tuple[bool, Any] - or str(return_annotation) == "Tuple[bool, Any]" - ): + # Convert annotation to string for comparison + annotation_str = str(return_annotation).lower() + valid_patterns = [ + 'tuple[bool, any]', + 'typing.tuple[bool, any]', + 'tuple[bool, str]', + 'tuple[bool, taskoutput]' + ] + if not any(pattern in annotation_str for pattern in valid_patterns): raise ValueError( - "If return type is annotated, it must be Tuple[bool, Any]" + "Return type must be tuple[bool, Any] or a specific type like " + "tuple[bool, str] or tuple[bool, TaskOutput]" ) return v diff --git a/tests/test_task_guardrails.py b/tests/test_task_guardrails.py index e22e76234..73ed66389 100644 --- a/tests/test_task_guardrails.py +++ b/tests/test_task_guardrails.py @@ -127,3 +127,45 @@ def test_guardrail_error_in_context(): assert "Task failed guardrail validation" in str(exc_info.value) assert "Expected JSON, got string" in str(exc_info.value) + + +def test_guardrail_with_new_style_annotation(): + """Test guardrail with new style tuple annotation.""" + def guardrail(result: TaskOutput) -> tuple[bool, str]: + return (True, result.raw.upper()) + + agent = Mock() + agent.role = "test_agent" + agent.execute_task.return_value = "test result" + agent.crew = None + + task = Task( + description="Test task", + expected_output="Output", + guardrail=guardrail + ) + + result = task.execute_sync(agent=agent) + assert isinstance(result, TaskOutput) + assert result.raw == "TEST RESULT" + + +def test_guardrail_with_optional_params(): + """Test guardrail with optional parameters.""" + def guardrail(result: TaskOutput, optional_param: str = "default") -> tuple[bool, str]: + return (True, f"{result.raw}-{optional_param}") + + agent = Mock() + agent.role = "test_agent" + agent.execute_task.return_value = "test" + agent.crew = None + + task = Task( + description="Test task", + expected_output="Output", + guardrail=guardrail + ) + + result = task.execute_sync(agent=agent) + assert isinstance(result, TaskOutput) + assert result.raw == "test-default"