Fix: union type support

This commit is contained in:
elda27
2025-02-21 01:27:32 +09:00
parent a5f862529d
commit 41961ca749
2 changed files with 7 additions and 4 deletions

View File

@@ -19,6 +19,8 @@ from typing import (
Tuple,
Type,
Union,
get_args,
get_origin,
)
from pydantic import (
@@ -183,7 +185,6 @@ class Task(BaseModel):
# Check return annotation if present, but don't require it
return_annotation = sig.return_annotation
if return_annotation != inspect.Signature.empty:
from typing import get_args, get_origin
return_annotation_args = get_args(return_annotation)
if not (
@@ -194,6 +195,7 @@ class Task(BaseModel):
return_annotation_args[1] is Any
or return_annotation_args[1] is str
or return_annotation_args[1] is TaskOutput
or return_annotation_args[1] == Union[str, TaskOutput]
)
):
raise ValueError(

View File

@@ -4,6 +4,7 @@ import hashlib
import json
import os
from functools import partial
from typing import Tuple, Union
from unittest.mock import MagicMock, patch
import pytest
@@ -241,11 +242,11 @@ def test_guardrail_type_error():
return (True, x)
@classmethod
def guardrail_class_fn(cls, x: TaskOutput) -> tuple[bool, TaskOutput]:
def guardrail_class_fn(cls, x: TaskOutput) -> tuple[bool, str]:
return (True, x)
@staticmethod
def guardrail_static_fn(x: TaskOutput) -> tuple[bool, TaskOutput]:
def guardrail_static_fn(x: TaskOutput) -> tuple[bool, Union[str, TaskOutput]]:
return (True, x)
obj = Object()
@@ -268,7 +269,7 @@ def test_guardrail_type_error():
guardrail=Object.guardrail_static_fn,
)
def error_fn(x: TaskOutput, y: bool) -> tuple[bool, TaskOutput]:
def error_fn(x: TaskOutput, y: bool) -> Tuple[bool, TaskOutput]:
return (y, x)
Task(