mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
refactor: incorporate improvements from PR #2178 into guardrail validation
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -194,19 +194,14 @@ class Task(BaseModel):
|
|||||||
"""
|
"""
|
||||||
if v is not None:
|
if v is not None:
|
||||||
sig = inspect.signature(v)
|
sig = inspect.signature(v)
|
||||||
# Get required positional parameters (excluding those with defaults)
|
# Check for exactly one required positional parameter
|
||||||
required_params = [
|
positional_args = [
|
||||||
param for param in sig.parameters.values()
|
param for param in sig.parameters.values()
|
||||||
if param.default == inspect.Parameter.empty
|
if param.default is inspect.Parameter.empty
|
||||||
and param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
|
||||||
]
|
]
|
||||||
keyword_only_params = [
|
if len(positional_args) != 1:
|
||||||
param for param in sig.parameters.values()
|
|
||||||
if param.kind == inspect.Parameter.KEYWORD_ONLY
|
|
||||||
]
|
|
||||||
if len(required_params) != 1 or (len(keyword_only_params) > 0 and any(p.default == inspect.Parameter.empty for p in keyword_only_params)):
|
|
||||||
raise GuardrailValidationError(
|
raise GuardrailValidationError(
|
||||||
"Guardrail function must accept exactly one required positional parameter and no required keyword-only parameters",
|
"Guardrail function must accept exactly one required parameter",
|
||||||
{"params": [str(p) for p in sig.parameters.values()]}
|
{"params": [str(p) for p in sig.parameters.values()]}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -214,33 +209,27 @@ class Task(BaseModel):
|
|||||||
type_hints = typing.get_type_hints(v)
|
type_hints = typing.get_type_hints(v)
|
||||||
return_annotation = type_hints.get('return')
|
return_annotation = type_hints.get('return')
|
||||||
if return_annotation:
|
if return_annotation:
|
||||||
# Convert annotation to string for comparison
|
# Simplified type checking logic
|
||||||
annotation_str = str(return_annotation).lower().replace(' ', '')
|
return_annotation_args = typing.get_args(return_annotation)
|
||||||
|
if not (
|
||||||
# Normalize type strings
|
typing.get_origin(return_annotation) is tuple
|
||||||
normalized_annotation = (
|
and len(return_annotation_args) == 2
|
||||||
annotation_str.replace('typing.', '')
|
and return_annotation_args[0] is bool
|
||||||
.replace('dict[str,typing.any]', 'dict[str,any]')
|
and (
|
||||||
.replace('dict[str, any]', 'dict[str,any]')
|
return_annotation_args[1] is Any
|
||||||
)
|
or return_annotation_args[1] is str
|
||||||
|
or return_annotation_args[1] is TaskOutput
|
||||||
VALID_RETURN_TYPES = {
|
or return_annotation_args[1] == Union[str, TaskOutput]
|
||||||
'tuple[bool,any]',
|
or (
|
||||||
'tuple[bool,str]',
|
typing.get_origin(return_annotation_args[1]) is dict
|
||||||
'tuple[bool,dict[str,any]]',
|
and typing.get_args(return_annotation_args[1])[0] is str
|
||||||
'tuple[bool,taskoutput]'
|
and typing.get_args(return_annotation_args[1])[1] is Any
|
||||||
}
|
)
|
||||||
|
)
|
||||||
# Check if the normalized annotation matches any valid pattern
|
):
|
||||||
is_valid = normalized_annotation == 'tuple[bool,any]'
|
|
||||||
if not is_valid:
|
|
||||||
is_valid = normalized_annotation in VALID_RETURN_TYPES
|
|
||||||
|
|
||||||
if not is_valid:
|
|
||||||
raise GuardrailValidationError(
|
raise GuardrailValidationError(
|
||||||
f"Invalid return type annotation. Expected one of: "
|
"Invalid return type annotation. Expected Tuple[bool, Any|str|TaskOutput|Dict[str, Any]]",
|
||||||
f"{', '.join(VALID_RETURN_TYPES)}",
|
{"got": str(return_annotation)}
|
||||||
{"got": annotation_str}
|
|
||||||
)
|
)
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user