refactor: implement code review suggestions

- Add GuardrailValidationError exception
- Use typing.get_type_hints for better type checking
- Add descriptive error messages with context
- Support dict return type

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-20 16:50:33 +00:00
parent 3c5672f4ec
commit 0e086d348a
2 changed files with 67 additions and 18 deletions

View File

@@ -3,6 +3,7 @@ import inspect
import json
import logging
import threading
import typing
import uuid
from concurrent.futures import Future
from copy import copy
@@ -32,6 +33,7 @@ from pydantic import (
from pydantic_core import PydanticCustomError
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.tasks.exceptions import GuardrailValidationError
from crewai.tasks.guardrail_result import GuardrailResult
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
@@ -151,8 +153,8 @@ class Task(BaseModel):
"""Validate that the guardrail function has the correct signature and behavior.
While type hints provide static checking, this validator ensures runtime safety by:
1. Verifying the function accepts exactly one parameter (the TaskOutput)
2. Checking return type annotations match Tuple[bool, Any] if present
1. Verifying the function accepts exactly one required positional parameter (the TaskOutput)
2. Checking return type annotations match tuple[bool, Any] or specific types like tuple[bool, str]
3. Providing clear, immediate error messages for debugging
This runtime validation is crucial because:
@@ -160,6 +162,24 @@ class Task(BaseModel):
- Function signatures need immediate validation before task execution
- Clear error messages help users debug guardrail implementation issues
Examples:
Simple validation with new style annotation:
>>> def validate_output(result: TaskOutput) -> tuple[bool, str]:
... return (True, result.raw.upper())
Validation with optional parameters:
>>> def validate_with_options(result: TaskOutput, strict: bool = True) -> tuple[bool, str]:
... if strict and not result.raw.isupper():
... return (False, "Text must be uppercase")
... return (True, result.raw)
Validation with specific return type:
>>> def validate_task_output(result: TaskOutput) -> tuple[bool, TaskOutput]:
... if not result.raw:
... return (False, result)
... result.raw = result.raw.strip()
... return (True, result)
Args:
v: The guardrail function to validate
@@ -168,33 +188,37 @@ class Task(BaseModel):
Raises:
ValueError: If the function signature is invalid or return annotation
doesn't match Tuple[bool, Any]
doesn't match tuple[bool, Any] or specific allowed types
"""
if v is not None:
sig = inspect.signature(v)
# Get required parameters (excluding those with defaults)
# Get required positional parameters (excluding those with defaults)
required_params = [
param for param in sig.parameters.values()
if param.default == inspect.Parameter.empty
if param.default == inspect.Parameter.empty
and param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
]
if len(required_params) != 1:
raise ValueError("Guardrail function must accept exactly one required parameter")
raise ValueError("Guardrail function must accept exactly one required positional parameter")
# Check return annotation if present, but don't require it
return_annotation = sig.return_annotation
if return_annotation != inspect.Signature.empty:
type_hints = typing.get_type_hints(v)
return_annotation = type_hints.get('return')
if return_annotation:
# Convert annotation to string for comparison
annotation_str = str(return_annotation).lower()
valid_patterns = [
'tuple[bool, any]',
'typing.tuple[bool, any]',
'tuple[bool, str]',
'tuple[bool, taskoutput]'
]
if not any(pattern in annotation_str for pattern in valid_patterns):
raise ValueError(
"Return type must be tuple[bool, Any] or a specific type like "
"tuple[bool, str] or tuple[bool, TaskOutput]"
VALID_RETURN_TYPES = {
'tuple[bool, any]': True,
'typing.tuple[bool, any]': True,
'tuple[bool, str]': True,
'tuple[bool, dict]': True,
'tuple[bool, taskoutput]': True
}
if not any(pattern in annotation_str for pattern in VALID_RETURN_TYPES):
raise GuardrailValidationError(
f"Invalid return type annotation. Expected one of: "
f"{', '.join(VALID_RETURN_TYPES.keys())}",
{"got": annotation_str}
)
return v

View File

@@ -0,0 +1,25 @@
"""
Module for task-related exceptions.
This module provides custom exceptions used throughout the task system
to provide more specific error handling and context.
"""
from typing import Any, Dict, Optional
class GuardrailValidationError(Exception):
"""Exception raised for guardrail validation errors.
This exception provides detailed context about why a guardrail
validation failed, including the specific validation that failed
and any relevant context information.
Attributes:
message: A clear description of the validation error
context: Optional dictionary containing additional error context
"""
def __init__(self, message: str, context: Optional[Dict[str, Any]] = None):
self.message = message
self.context = context or {}
super().__init__(self.message)