refactor: incorporate improvements from PR #2178 into guardrail validation

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-21 03:24:21 +00:00
parent f3f094faad
commit 652f3f5b8e

View File

@@ -194,19 +194,14 @@ class Task(BaseModel):
""" """
if v is not None: if v is not None:
sig = inspect.signature(v) sig = inspect.signature(v)
# Get required positional parameters (excluding those with defaults) # Check for exactly one required positional parameter
required_params = [ positional_args = [
param for param in sig.parameters.values() param for param in sig.parameters.values()
if param.default == inspect.Parameter.empty if param.default is inspect.Parameter.empty
and param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
] ]
keyword_only_params = [ if len(positional_args) != 1:
param for param in sig.parameters.values()
if param.kind == inspect.Parameter.KEYWORD_ONLY
]
if len(required_params) != 1 or (len(keyword_only_params) > 0 and any(p.default == inspect.Parameter.empty for p in keyword_only_params)):
raise GuardrailValidationError( raise GuardrailValidationError(
"Guardrail function must accept exactly one required positional parameter and no required keyword-only parameters", "Guardrail function must accept exactly one required parameter",
{"params": [str(p) for p in sig.parameters.values()]} {"params": [str(p) for p in sig.parameters.values()]}
) )
@@ -214,33 +209,27 @@ class Task(BaseModel):
type_hints = typing.get_type_hints(v) type_hints = typing.get_type_hints(v)
return_annotation = type_hints.get('return') return_annotation = type_hints.get('return')
if return_annotation: if return_annotation:
# Convert annotation to string for comparison # Simplified type checking logic
annotation_str = str(return_annotation).lower().replace(' ', '') return_annotation_args = typing.get_args(return_annotation)
if not (
# Normalize type strings typing.get_origin(return_annotation) is tuple
normalized_annotation = ( and len(return_annotation_args) == 2
annotation_str.replace('typing.', '') and return_annotation_args[0] is bool
.replace('dict[str,typing.any]', 'dict[str,any]') and (
.replace('dict[str, any]', 'dict[str,any]') return_annotation_args[1] is Any
) or return_annotation_args[1] is str
or return_annotation_args[1] is TaskOutput
VALID_RETURN_TYPES = { or return_annotation_args[1] == Union[str, TaskOutput]
'tuple[bool,any]', or (
'tuple[bool,str]', typing.get_origin(return_annotation_args[1]) is dict
'tuple[bool,dict[str,any]]', and typing.get_args(return_annotation_args[1])[0] is str
'tuple[bool,taskoutput]' and typing.get_args(return_annotation_args[1])[1] is Any
} )
)
# Check if the normalized annotation matches any valid pattern ):
is_valid = normalized_annotation == 'tuple[bool,any]'
if not is_valid:
is_valid = normalized_annotation in VALID_RETURN_TYPES
if not is_valid:
raise GuardrailValidationError( raise GuardrailValidationError(
f"Invalid return type annotation. Expected one of: " "Invalid return type annotation. Expected Tuple[bool, Any|str|TaskOutput|Dict[str, Any]]",
f"{', '.join(VALID_RETURN_TYPES)}", {"got": str(return_annotation)}
{"got": annotation_str}
) )
return v return v