refactor: implement code review suggestions

- Add GuardrailValidationError exception
- Use typing.get_type_hints for better type checking
- Add descriptive error messages with context
- Support dict return type

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-20 16:50:33 +00:00
parent 3c5672f4ec
commit 0e086d348a
2 changed files with 67 additions and 18 deletions

View File

@@ -3,6 +3,7 @@ import inspect
import json import json
import logging import logging
import threading import threading
import typing
import uuid import uuid
from concurrent.futures import Future from concurrent.futures import Future
from copy import copy from copy import copy
@@ -32,6 +33,7 @@ from pydantic import (
from pydantic_core import PydanticCustomError from pydantic_core import PydanticCustomError
from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.tasks.exceptions import GuardrailValidationError
from crewai.tasks.guardrail_result import GuardrailResult from crewai.tasks.guardrail_result import GuardrailResult
from crewai.tasks.output_format import OutputFormat from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput from crewai.tasks.task_output import TaskOutput
@@ -151,8 +153,8 @@ class Task(BaseModel):
"""Validate that the guardrail function has the correct signature and behavior. """Validate that the guardrail function has the correct signature and behavior.
While type hints provide static checking, this validator ensures runtime safety by: While type hints provide static checking, this validator ensures runtime safety by:
1. Verifying the function accepts exactly one parameter (the TaskOutput) 1. Verifying the function accepts exactly one required positional parameter (the TaskOutput)
2. Checking return type annotations match Tuple[bool, Any] if present 2. Checking return type annotations match tuple[bool, Any] or specific types like tuple[bool, str]
3. Providing clear, immediate error messages for debugging 3. Providing clear, immediate error messages for debugging
This runtime validation is crucial because: This runtime validation is crucial because:
@@ -160,6 +162,24 @@ class Task(BaseModel):
- Function signatures need immediate validation before task execution - Function signatures need immediate validation before task execution
- Clear error messages help users debug guardrail implementation issues - Clear error messages help users debug guardrail implementation issues
Examples:
Simple validation with new style annotation:
>>> def validate_output(result: TaskOutput) -> tuple[bool, str]:
... return (True, result.raw.upper())
Validation with optional parameters:
>>> def validate_with_options(result: TaskOutput, strict: bool = True) -> tuple[bool, str]:
... if strict and not result.raw.isupper():
... return (False, "Text must be uppercase")
... return (True, result.raw)
Validation with specific return type:
>>> def validate_task_output(result: TaskOutput) -> tuple[bool, TaskOutput]:
... if not result.raw:
... return (False, result)
... result.raw = result.raw.strip()
... return (True, result)
Args: Args:
v: The guardrail function to validate v: The guardrail function to validate
@@ -168,33 +188,37 @@ class Task(BaseModel):
Raises: Raises:
ValueError: If the function signature is invalid or return annotation ValueError: If the function signature is invalid or return annotation
doesn't match Tuple[bool, Any] doesn't match tuple[bool, Any] or specific allowed types
""" """
if v is not None: if v is not None:
sig = inspect.signature(v) sig = inspect.signature(v)
# Get required parameters (excluding those with defaults) # Get required positional parameters (excluding those with defaults)
required_params = [ required_params = [
param for param in sig.parameters.values() param for param in sig.parameters.values()
if param.default == inspect.Parameter.empty if param.default == inspect.Parameter.empty
and param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
] ]
if len(required_params) != 1: if len(required_params) != 1:
raise ValueError("Guardrail function must accept exactly one required parameter") raise ValueError("Guardrail function must accept exactly one required positional parameter")
# Check return annotation if present, but don't require it # Check return annotation if present, but don't require it
return_annotation = sig.return_annotation type_hints = typing.get_type_hints(v)
if return_annotation != inspect.Signature.empty: return_annotation = type_hints.get('return')
if return_annotation:
# Convert annotation to string for comparison # Convert annotation to string for comparison
annotation_str = str(return_annotation).lower() annotation_str = str(return_annotation).lower()
valid_patterns = [ VALID_RETURN_TYPES = {
'tuple[bool, any]', 'tuple[bool, any]': True,
'typing.tuple[bool, any]', 'typing.tuple[bool, any]': True,
'tuple[bool, str]', 'tuple[bool, str]': True,
'tuple[bool, taskoutput]' 'tuple[bool, dict]': True,
] 'tuple[bool, taskoutput]': True
if not any(pattern in annotation_str for pattern in valid_patterns): }
raise ValueError( if not any(pattern in annotation_str for pattern in VALID_RETURN_TYPES):
"Return type must be tuple[bool, Any] or a specific type like " raise GuardrailValidationError(
"tuple[bool, str] or tuple[bool, TaskOutput]" f"Invalid return type annotation. Expected one of: "
f"{', '.join(VALID_RETURN_TYPES.keys())}",
{"got": annotation_str}
) )
return v return v

View File

@@ -0,0 +1,25 @@
"""
Module for task-related exceptions.
This module provides custom exceptions used throughout the task system
to provide more specific error handling and context.
"""
from typing import Any, Dict, Optional
class GuardrailValidationError(Exception):
"""Exception raised for guardrail validation errors.
This exception provides detailed context about why a guardrail
validation failed, including the specific validation that failed
and any relevant context information.
Attributes:
message: A clear description of the validation error
context: Optional dictionary containing additional error context
"""
def __init__(self, message: str, context: Optional[Dict[str, Any]] = None):
self.message = message
self.context = context or {}
super().__init__(self.message)