mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
* Fix: More confortable validation * Fix: union type support --------- Co-authored-by: Brandon Hancock (bhancock_ai) <109994880+bhancockio@users.noreply.github.com>
This commit is contained in:
@@ -3,6 +3,8 @@
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Tuple, Union
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -215,6 +217,75 @@ def test_multiple_output_type_error():
|
||||
)
|
||||
|
||||
|
||||
def test_guardrail_type_error():
|
||||
desc = "Give me a list of 5 interesting ideas to explore for na article, what makes them unique and interesting."
|
||||
expected_output = "Bullet point list of 5 interesting ideas."
|
||||
# Lambda function
|
||||
Task(
|
||||
description=desc,
|
||||
expected_output=expected_output,
|
||||
guardrail=lambda x: (True, x),
|
||||
)
|
||||
|
||||
# Function
|
||||
def guardrail_fn(x: TaskOutput) -> tuple[bool, TaskOutput]:
|
||||
return (True, x)
|
||||
|
||||
Task(
|
||||
description=desc,
|
||||
expected_output=expected_output,
|
||||
guardrail=guardrail_fn,
|
||||
)
|
||||
|
||||
class Object:
|
||||
def guardrail_fn(self, x: TaskOutput) -> tuple[bool, TaskOutput]:
|
||||
return (True, x)
|
||||
|
||||
@classmethod
|
||||
def guardrail_class_fn(cls, x: TaskOutput) -> tuple[bool, str]:
|
||||
return (True, x)
|
||||
|
||||
@staticmethod
|
||||
def guardrail_static_fn(x: TaskOutput) -> tuple[bool, Union[str, TaskOutput]]:
|
||||
return (True, x)
|
||||
|
||||
obj = Object()
|
||||
# Method
|
||||
Task(
|
||||
description=desc,
|
||||
expected_output=expected_output,
|
||||
guardrail=obj.guardrail_fn,
|
||||
)
|
||||
# Class method
|
||||
Task(
|
||||
description=desc,
|
||||
expected_output=expected_output,
|
||||
guardrail=Object.guardrail_class_fn,
|
||||
)
|
||||
# Static method
|
||||
Task(
|
||||
description=desc,
|
||||
expected_output=expected_output,
|
||||
guardrail=Object.guardrail_static_fn,
|
||||
)
|
||||
|
||||
def error_fn(x: TaskOutput, y: bool) -> Tuple[bool, TaskOutput]:
|
||||
return (y, x)
|
||||
|
||||
Task(
|
||||
description=desc,
|
||||
expected_output=expected_output,
|
||||
guardrail=partial(error_fn, y=True),
|
||||
)
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
Task(
|
||||
description=desc,
|
||||
expected_output=expected_output,
|
||||
guardrail=error_fn,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_output_pydantic_sequential():
|
||||
class ScoreOutput(BaseModel):
|
||||
|
||||
Reference in New Issue
Block a user