diff --git a/src/crewai/task.py b/src/crewai/task.py index be400e99a..0c063e4f9 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -19,6 +19,8 @@ from typing import ( Tuple, Type, Union, + get_args, + get_origin, ) from pydantic import ( @@ -178,15 +180,29 @@ class Task(BaseModel): """ if v is not None: sig = inspect.signature(v) - if len(sig.parameters) != 1: + positional_args = [ + param + for param in sig.parameters.values() + if param.default is inspect.Parameter.empty + ] + if len(positional_args) != 1: raise ValueError("Guardrail function must accept exactly one parameter") # Check return annotation if present, but don't require it return_annotation = sig.return_annotation if return_annotation != inspect.Signature.empty: + + return_annotation_args = get_args(return_annotation) if not ( - return_annotation == Tuple[bool, Any] - or str(return_annotation) == "Tuple[bool, Any]" + get_origin(return_annotation) is tuple + and len(return_annotation_args) == 2 + and return_annotation_args[0] is bool + and ( + return_annotation_args[1] is Any + or return_annotation_args[1] is str + or return_annotation_args[1] is TaskOutput + or return_annotation_args[1] == Union[str, TaskOutput] + ) ): raise ValueError( "If return type is annotated, it must be Tuple[bool, Any]" diff --git a/tests/task_test.py b/tests/task_test.py index 3cd11cfc7..ac25a14f8 100644 --- a/tests/task_test.py +++ b/tests/task_test.py @@ -3,6 +3,8 @@ import hashlib import json import os +from functools import partial +from typing import Tuple, Union from unittest.mock import MagicMock, patch import pytest @@ -215,6 +217,75 @@ def test_multiple_output_type_error(): ) +def test_guardrail_type_error(): + desc = "Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting." + expected_output = "Bullet point list of 5 interesting ideas." + # Lambda function + Task( + description=desc, + expected_output=expected_output, + guardrail=lambda x: (True, x), + ) + + # Function + def guardrail_fn(x: TaskOutput) -> tuple[bool, TaskOutput]: + return (True, x) + + Task( + description=desc, + expected_output=expected_output, + guardrail=guardrail_fn, + ) + + class Object: + def guardrail_fn(self, x: TaskOutput) -> tuple[bool, TaskOutput]: + return (True, x) + + @classmethod + def guardrail_class_fn(cls, x: TaskOutput) -> tuple[bool, str]: + return (True, x) + + @staticmethod + def guardrail_static_fn(x: TaskOutput) -> tuple[bool, Union[str, TaskOutput]]: + return (True, x) + + obj = Object() + # Method + Task( + description=desc, + expected_output=expected_output, + guardrail=obj.guardrail_fn, + ) + # Class method + Task( + description=desc, + expected_output=expected_output, + guardrail=Object.guardrail_class_fn, + ) + # Static method + Task( + description=desc, + expected_output=expected_output, + guardrail=Object.guardrail_static_fn, + ) + + def error_fn(x: TaskOutput, y: bool) -> Tuple[bool, TaskOutput]: + return (y, x) + + Task( + description=desc, + expected_output=expected_output, + guardrail=partial(error_fn, y=True), + ) + + with pytest.raises(ValidationError): + Task( + description=desc, + expected_output=expected_output, + guardrail=error_fn, + ) + + @pytest.mark.vcr(filter_headers=["authorization"]) def test_output_pydantic_sequential(): class ScoreOutput(BaseModel):