refactor: incorporate improvements from PR #2178 into guardrail validation

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-21 03:24:21 +00:00
parent f3f094faad
commit 652f3f5b8e

View File

@@ -194,19 +194,14 @@ class Task(BaseModel):
"""
if v is not None:
sig = inspect.signature(v)
# Get required positional parameters (excluding those with defaults)
required_params = [
# Check for exactly one required positional parameter
positional_args = [
param for param in sig.parameters.values()
if param.default == inspect.Parameter.empty
and param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
if param.default is inspect.Parameter.empty
]
keyword_only_params = [
param for param in sig.parameters.values()
if param.kind == inspect.Parameter.KEYWORD_ONLY
]
if len(required_params) != 1 or (len(keyword_only_params) > 0 and any(p.default == inspect.Parameter.empty for p in keyword_only_params)):
if len(positional_args) != 1:
raise GuardrailValidationError(
"Guardrail function must accept exactly one required positional parameter and no required keyword-only parameters",
"Guardrail function must accept exactly one required parameter",
{"params": [str(p) for p in sig.parameters.values()]}
)
@@ -214,33 +209,27 @@ class Task(BaseModel):
type_hints = typing.get_type_hints(v)
return_annotation = type_hints.get('return')
if return_annotation:
# Convert annotation to string for comparison
annotation_str = str(return_annotation).lower().replace(' ', '')
# Normalize type strings
normalized_annotation = (
annotation_str.replace('typing.', '')
.replace('dict[str,typing.any]', 'dict[str,any]')
.replace('dict[str, any]', 'dict[str,any]')
)
VALID_RETURN_TYPES = {
'tuple[bool,any]',
'tuple[bool,str]',
'tuple[bool,dict[str,any]]',
'tuple[bool,taskoutput]'
}
# Check if the normalized annotation matches any valid pattern
is_valid = normalized_annotation == 'tuple[bool,any]'
if not is_valid:
is_valid = normalized_annotation in VALID_RETURN_TYPES
if not is_valid:
# Simplified type checking logic
return_annotation_args = typing.get_args(return_annotation)
if not (
typing.get_origin(return_annotation) is tuple
and len(return_annotation_args) == 2
and return_annotation_args[0] is bool
and (
return_annotation_args[1] is Any
or return_annotation_args[1] is str
or return_annotation_args[1] is TaskOutput
or return_annotation_args[1] == Union[str, TaskOutput]
or (
typing.get_origin(return_annotation_args[1]) is dict
and typing.get_args(return_annotation_args[1])[0] is str
and typing.get_args(return_annotation_args[1])[1] is Any
)
)
):
raise GuardrailValidationError(
f"Invalid return type annotation. Expected one of: "
f"{', '.join(VALID_RETURN_TYPES)}",
{"got": annotation_str}
"Invalid return type annotation. Expected Tuple[bool, Any|str|TaskOutput|Dict[str, Any]]",
{"got": str(return_annotation)}
)
return v