mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
* Fix: More confortable validation * Fix: union type support --------- Co-authored-by: Brandon Hancock (bhancock_ai) <109994880+bhancockio@users.noreply.github.com>
This commit is contained in:
@@ -19,6 +19,8 @@ from typing import (
|
|||||||
Tuple,
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
Union,
|
Union,
|
||||||
|
get_args,
|
||||||
|
get_origin,
|
||||||
)
|
)
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
@@ -178,15 +180,29 @@ class Task(BaseModel):
|
|||||||
"""
|
"""
|
||||||
if v is not None:
|
if v is not None:
|
||||||
sig = inspect.signature(v)
|
sig = inspect.signature(v)
|
||||||
if len(sig.parameters) != 1:
|
positional_args = [
|
||||||
|
param
|
||||||
|
for param in sig.parameters.values()
|
||||||
|
if param.default is inspect.Parameter.empty
|
||||||
|
]
|
||||||
|
if len(positional_args) != 1:
|
||||||
raise ValueError("Guardrail function must accept exactly one parameter")
|
raise ValueError("Guardrail function must accept exactly one parameter")
|
||||||
|
|
||||||
# Check return annotation if present, but don't require it
|
# Check return annotation if present, but don't require it
|
||||||
return_annotation = sig.return_annotation
|
return_annotation = sig.return_annotation
|
||||||
if return_annotation != inspect.Signature.empty:
|
if return_annotation != inspect.Signature.empty:
|
||||||
|
|
||||||
|
return_annotation_args = get_args(return_annotation)
|
||||||
if not (
|
if not (
|
||||||
return_annotation == Tuple[bool, Any]
|
get_origin(return_annotation) is tuple
|
||||||
or str(return_annotation) == "Tuple[bool, Any]"
|
and len(return_annotation_args) == 2
|
||||||
|
and return_annotation_args[0] is bool
|
||||||
|
and (
|
||||||
|
return_annotation_args[1] is Any
|
||||||
|
or return_annotation_args[1] is str
|
||||||
|
or return_annotation_args[1] is TaskOutput
|
||||||
|
or return_annotation_args[1] == Union[str, TaskOutput]
|
||||||
|
)
|
||||||
):
|
):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"If return type is annotated, it must be Tuple[bool, Any]"
|
"If return type is annotated, it must be Tuple[bool, Any]"
|
||||||
|
|||||||
@@ -3,6 +3,8 @@
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
from functools import partial
|
||||||
|
from typing import Tuple, Union
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@@ -215,6 +217,75 @@ def test_multiple_output_type_error():
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_guardrail_type_error():
|
||||||
|
desc = "Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting."
|
||||||
|
expected_output = "Bullet point list of 5 interesting ideas."
|
||||||
|
# Lambda function
|
||||||
|
Task(
|
||||||
|
description=desc,
|
||||||
|
expected_output=expected_output,
|
||||||
|
guardrail=lambda x: (True, x),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Function
|
||||||
|
def guardrail_fn(x: TaskOutput) -> tuple[bool, TaskOutput]:
|
||||||
|
return (True, x)
|
||||||
|
|
||||||
|
Task(
|
||||||
|
description=desc,
|
||||||
|
expected_output=expected_output,
|
||||||
|
guardrail=guardrail_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
class Object:
|
||||||
|
def guardrail_fn(self, x: TaskOutput) -> tuple[bool, TaskOutput]:
|
||||||
|
return (True, x)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def guardrail_class_fn(cls, x: TaskOutput) -> tuple[bool, str]:
|
||||||
|
return (True, x)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def guardrail_static_fn(x: TaskOutput) -> tuple[bool, Union[str, TaskOutput]]:
|
||||||
|
return (True, x)
|
||||||
|
|
||||||
|
obj = Object()
|
||||||
|
# Method
|
||||||
|
Task(
|
||||||
|
description=desc,
|
||||||
|
expected_output=expected_output,
|
||||||
|
guardrail=obj.guardrail_fn,
|
||||||
|
)
|
||||||
|
# Class method
|
||||||
|
Task(
|
||||||
|
description=desc,
|
||||||
|
expected_output=expected_output,
|
||||||
|
guardrail=Object.guardrail_class_fn,
|
||||||
|
)
|
||||||
|
# Static method
|
||||||
|
Task(
|
||||||
|
description=desc,
|
||||||
|
expected_output=expected_output,
|
||||||
|
guardrail=Object.guardrail_static_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
def error_fn(x: TaskOutput, y: bool) -> Tuple[bool, TaskOutput]:
|
||||||
|
return (y, x)
|
||||||
|
|
||||||
|
Task(
|
||||||
|
description=desc,
|
||||||
|
expected_output=expected_output,
|
||||||
|
guardrail=partial(error_fn, y=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
Task(
|
||||||
|
description=desc,
|
||||||
|
expected_output=expected_output,
|
||||||
|
guardrail=error_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||||
def test_output_pydantic_sequential():
|
def test_output_pydantic_sequential():
|
||||||
class ScoreOutput(BaseModel):
|
class ScoreOutput(BaseModel):
|
||||||
|
|||||||
Reference in New Issue
Block a user