Fix: More confortable validation

This commit is contained in:
elda27
2025-02-21 00:34:47 +09:00
parent 00c2f5043e
commit a5f862529d
2 changed files with 90 additions and 6 deletions

View File

@@ -172,15 +172,29 @@ class Task(BaseModel):
"""
if v is not None:
sig = inspect.signature(v)
if len(sig.parameters) != 1:
positional_args = [
param
for param in sig.parameters.values()
if param.default is inspect.Parameter.empty
]
if len(positional_args) != 1:
raise ValueError("Guardrail function must accept exactly one parameter")
# Check return annotation if present, but don't require it
return_annotation = sig.return_annotation
if return_annotation != inspect.Signature.empty:
from typing import get_args, get_origin
return_annotation_args = get_args(return_annotation)
if not (
return_annotation == Tuple[bool, Any]
or str(return_annotation) == "Tuple[bool, Any]"
get_origin(return_annotation) is tuple
and len(return_annotation_args) == 2
and return_annotation_args[0] is bool
and (
return_annotation_args[1] is Any
or return_annotation_args[1] is str
or return_annotation_args[1] is TaskOutput
)
):
raise ValueError(
"If return type is annotated, it must be Tuple[bool, Any]"
@@ -435,9 +449,9 @@ class Task(BaseModel):
content = (
json_output
if json_output
else pydantic_output.model_dump_json()
if pydantic_output
else result
else (
pydantic_output.model_dump_json() if pydantic_output else result
)
)
self._save_file(content)
crewai_event_bus.emit(self, TaskCompletedEvent(output=task_output))

View File

@@ -3,6 +3,7 @@
import hashlib
import json
import os
from functools import partial
from unittest.mock import MagicMock, patch
import pytest
@@ -215,6 +216,75 @@ def test_multiple_output_type_error():
)
def test_guardrail_type_error():
desc = "Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting."
expected_output = "Bullet point list of 5 interesting ideas."
# Lambda function
Task(
description=desc,
expected_output=expected_output,
guardrail=lambda x: (True, x),
)
# Function
def guardrail_fn(x: TaskOutput) -> tuple[bool, TaskOutput]:
return (True, x)
Task(
description=desc,
expected_output=expected_output,
guardrail=guardrail_fn,
)
class Object:
def guardrail_fn(self, x: TaskOutput) -> tuple[bool, TaskOutput]:
return (True, x)
@classmethod
def guardrail_class_fn(cls, x: TaskOutput) -> tuple[bool, TaskOutput]:
return (True, x)
@staticmethod
def guardrail_static_fn(x: TaskOutput) -> tuple[bool, TaskOutput]:
return (True, x)
obj = Object()
# Method
Task(
description=desc,
expected_output=expected_output,
guardrail=obj.guardrail_fn,
)
# Class method
Task(
description=desc,
expected_output=expected_output,
guardrail=Object.guardrail_class_fn,
)
# Static method
Task(
description=desc,
expected_output=expected_output,
guardrail=Object.guardrail_static_fn,
)
def error_fn(x: TaskOutput, y: bool) -> tuple[bool, TaskOutput]:
return (y, x)
Task(
description=desc,
expected_output=expected_output,
guardrail=partial(error_fn, y=True),
)
with pytest.raises(ValidationError):
Task(
description=desc,
expected_output=expected_output,
guardrail=error_fn,
)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_output_pydantic_sequential():
class ScoreOutput(BaseModel):