Fix: More comfortable validation #2177 (#2178)

* Fix: More confortable validation

* Fix: union type support

---------

Co-authored-by: Brandon Hancock (bhancock_ai) <109994880+bhancockio@users.noreply.github.com>
This commit is contained in:
elda27
2025-03-20 22:28:31 +09:00
committed by GitHub
parent fe0813e831
commit 520933b4c5
2 changed files with 90 additions and 3 deletions

View File

@@ -19,6 +19,8 @@ from typing import (
Tuple, Tuple,
Type, Type,
Union, Union,
get_args,
get_origin,
) )
from pydantic import ( from pydantic import (
@@ -178,15 +180,29 @@ class Task(BaseModel):
""" """
if v is not None: if v is not None:
sig = inspect.signature(v) sig = inspect.signature(v)
if len(sig.parameters) != 1: positional_args = [
param
for param in sig.parameters.values()
if param.default is inspect.Parameter.empty
]
if len(positional_args) != 1:
raise ValueError("Guardrail function must accept exactly one parameter") raise ValueError("Guardrail function must accept exactly one parameter")
# Check return annotation if present, but don't require it # Check return annotation if present, but don't require it
return_annotation = sig.return_annotation return_annotation = sig.return_annotation
if return_annotation != inspect.Signature.empty: if return_annotation != inspect.Signature.empty:
return_annotation_args = get_args(return_annotation)
if not ( if not (
return_annotation == Tuple[bool, Any] get_origin(return_annotation) is tuple
or str(return_annotation) == "Tuple[bool, Any]" and len(return_annotation_args) == 2
and return_annotation_args[0] is bool
and (
return_annotation_args[1] is Any
or return_annotation_args[1] is str
or return_annotation_args[1] is TaskOutput
or return_annotation_args[1] == Union[str, TaskOutput]
)
): ):
raise ValueError( raise ValueError(
"If return type is annotated, it must be Tuple[bool, Any]" "If return type is annotated, it must be Tuple[bool, Any]"

View File

@@ -3,6 +3,8 @@
import hashlib import hashlib
import json import json
import os import os
from functools import partial
from typing import Tuple, Union
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
import pytest import pytest
@@ -215,6 +217,75 @@ def test_multiple_output_type_error():
) )
def test_guardrail_type_error():
desc = "Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting."
expected_output = "Bullet point list of 5 interesting ideas."
# Lambda function
Task(
description=desc,
expected_output=expected_output,
guardrail=lambda x: (True, x),
)
# Function
def guardrail_fn(x: TaskOutput) -> tuple[bool, TaskOutput]:
return (True, x)
Task(
description=desc,
expected_output=expected_output,
guardrail=guardrail_fn,
)
class Object:
def guardrail_fn(self, x: TaskOutput) -> tuple[bool, TaskOutput]:
return (True, x)
@classmethod
def guardrail_class_fn(cls, x: TaskOutput) -> tuple[bool, str]:
return (True, x)
@staticmethod
def guardrail_static_fn(x: TaskOutput) -> tuple[bool, Union[str, TaskOutput]]:
return (True, x)
obj = Object()
# Method
Task(
description=desc,
expected_output=expected_output,
guardrail=obj.guardrail_fn,
)
# Class method
Task(
description=desc,
expected_output=expected_output,
guardrail=Object.guardrail_class_fn,
)
# Static method
Task(
description=desc,
expected_output=expected_output,
guardrail=Object.guardrail_static_fn,
)
def error_fn(x: TaskOutput, y: bool) -> Tuple[bool, TaskOutput]:
return (y, x)
Task(
description=desc,
expected_output=expected_output,
guardrail=partial(error_fn, y=True),
)
with pytest.raises(ValidationError):
Task(
description=desc,
expected_output=expected_output,
guardrail=error_fn,
)
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
def test_output_pydantic_sequential(): def test_output_pydantic_sequential():
class ScoreOutput(BaseModel): class ScoreOutput(BaseModel):