diff --git a/src/crewai/task.py b/src/crewai/task.py index 7f12bebab..748e401e4 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -19,6 +19,8 @@ from typing import ( Tuple, Type, Union, + get_args, + get_origin, ) from pydantic import ( @@ -183,7 +185,6 @@ class Task(BaseModel): # Check return annotation if present, but don't require it return_annotation = sig.return_annotation if return_annotation != inspect.Signature.empty: - from typing import get_args, get_origin return_annotation_args = get_args(return_annotation) if not ( @@ -194,6 +195,7 @@ class Task(BaseModel): return_annotation_args[1] is Any or return_annotation_args[1] is str or return_annotation_args[1] is TaskOutput + or return_annotation_args[1] == Union[str, TaskOutput] ) ): raise ValueError( diff --git a/tests/task_test.py b/tests/task_test.py index eed69859c..ac25a14f8 100644 --- a/tests/task_test.py +++ b/tests/task_test.py @@ -4,6 +4,7 @@ import hashlib import json import os from functools import partial +from typing import Tuple, Union from unittest.mock import MagicMock, patch import pytest @@ -241,11 +242,11 @@ def test_guardrail_type_error(): return (True, x) @classmethod - def guardrail_class_fn(cls, x: TaskOutput) -> tuple[bool, TaskOutput]: + def guardrail_class_fn(cls, x: TaskOutput) -> tuple[bool, str]: return (True, x) @staticmethod - def guardrail_static_fn(x: TaskOutput) -> tuple[bool, TaskOutput]: + def guardrail_static_fn(x: TaskOutput) -> tuple[bool, Union[str, TaskOutput]]: return (True, x) obj = Object() @@ -268,7 +269,7 @@ def test_guardrail_type_error(): guardrail=Object.guardrail_static_fn, ) - def error_fn(x: TaskOutput, y: bool) -> tuple[bool, TaskOutput]: + def error_fn(x: TaskOutput, y: bool) -> Tuple[bool, TaskOutput]: return (y, x) Task(