mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
refactor: incorporate improvements from PR #2178 into guardrail validation
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -194,19 +194,14 @@ class Task(BaseModel):
|
||||
"""
|
||||
if v is not None:
|
||||
sig = inspect.signature(v)
|
||||
# Get required positional parameters (excluding those with defaults)
|
||||
required_params = [
|
||||
# Check for exactly one required positional parameter
|
||||
positional_args = [
|
||||
param for param in sig.parameters.values()
|
||||
if param.default == inspect.Parameter.empty
|
||||
and param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
||||
if param.default is inspect.Parameter.empty
|
||||
]
|
||||
keyword_only_params = [
|
||||
param for param in sig.parameters.values()
|
||||
if param.kind == inspect.Parameter.KEYWORD_ONLY
|
||||
]
|
||||
if len(required_params) != 1 or (len(keyword_only_params) > 0 and any(p.default == inspect.Parameter.empty for p in keyword_only_params)):
|
||||
if len(positional_args) != 1:
|
||||
raise GuardrailValidationError(
|
||||
"Guardrail function must accept exactly one required positional parameter and no required keyword-only parameters",
|
||||
"Guardrail function must accept exactly one required parameter",
|
||||
{"params": [str(p) for p in sig.parameters.values()]}
|
||||
)
|
||||
|
||||
@@ -214,33 +209,27 @@ class Task(BaseModel):
|
||||
type_hints = typing.get_type_hints(v)
|
||||
return_annotation = type_hints.get('return')
|
||||
if return_annotation:
|
||||
# Convert annotation to string for comparison
|
||||
annotation_str = str(return_annotation).lower().replace(' ', '')
|
||||
|
||||
# Normalize type strings
|
||||
normalized_annotation = (
|
||||
annotation_str.replace('typing.', '')
|
||||
.replace('dict[str,typing.any]', 'dict[str,any]')
|
||||
.replace('dict[str, any]', 'dict[str,any]')
|
||||
)
|
||||
|
||||
VALID_RETURN_TYPES = {
|
||||
'tuple[bool,any]',
|
||||
'tuple[bool,str]',
|
||||
'tuple[bool,dict[str,any]]',
|
||||
'tuple[bool,taskoutput]'
|
||||
}
|
||||
|
||||
# Check if the normalized annotation matches any valid pattern
|
||||
is_valid = normalized_annotation == 'tuple[bool,any]'
|
||||
if not is_valid:
|
||||
is_valid = normalized_annotation in VALID_RETURN_TYPES
|
||||
|
||||
if not is_valid:
|
||||
# Simplified type checking logic
|
||||
return_annotation_args = typing.get_args(return_annotation)
|
||||
if not (
|
||||
typing.get_origin(return_annotation) is tuple
|
||||
and len(return_annotation_args) == 2
|
||||
and return_annotation_args[0] is bool
|
||||
and (
|
||||
return_annotation_args[1] is Any
|
||||
or return_annotation_args[1] is str
|
||||
or return_annotation_args[1] is TaskOutput
|
||||
or return_annotation_args[1] == Union[str, TaskOutput]
|
||||
or (
|
||||
typing.get_origin(return_annotation_args[1]) is dict
|
||||
and typing.get_args(return_annotation_args[1])[0] is str
|
||||
and typing.get_args(return_annotation_args[1])[1] is Any
|
||||
)
|
||||
)
|
||||
):
|
||||
raise GuardrailValidationError(
|
||||
f"Invalid return type annotation. Expected one of: "
|
||||
f"{', '.join(VALID_RETURN_TYPES)}",
|
||||
{"got": annotation_str}
|
||||
"Invalid return type annotation. Expected Tuple[bool, Any|str|TaskOutput|Dict[str, Any]]",
|
||||
{"got": str(return_annotation)}
|
||||
)
|
||||
return v
|
||||
|
||||
|
||||
Reference in New Issue
Block a user