Compare commits

...

23 Commits

Author SHA1 Message Date
João Moura
828a567017 Merge branch 'main' into devin/1740069574-improve-guardrail-validation 2025-02-21 00:27:19 -03:00
Devin AI
f3f094faad fix: sort imports in task.py
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 17:41:42 +00:00
Devin AI
ad0db27040 fix: use Any type for TaskOutput.raw to support both string and dict
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 17:38:32 +00:00
Devin AI
6cc37a38a6 fix: add missing TaskOutput import
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 17:37:13 +00:00
Devin AI
4bafdacd88 fix: allow both string and dict types for TaskOutput.raw
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 17:35:47 +00:00
Devin AI
2793a8ee87 fix: preserve dictionary type in guardrail results
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 17:32:36 +00:00
Devin AI
a9b0702cbe fix: type error in dictionary result handling
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 17:28:36 +00:00
Devin AI
54bdc8b52c fix: properly handle dictionary results in guardrails
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 17:26:54 +00:00
Devin AI
578164cf05 fix: preserve dictionary type in guardrail results
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 17:25:44 +00:00
Devin AI
fe78553c9c fix: ensure dict results are properly serialized to string
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 17:13:27 +00:00
Devin AI
0cb47a8d0a fix: improve Task.copy() method to handle both legacy and new style copy
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 17:12:02 +00:00
Devin AI
a87a7d2833 fix: restore backward compatibility for Task.copy() and fix dict handling
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 17:09:32 +00:00
Devin AI
5b833d932e fix: correct type error in error message formatting
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 17:04:35 +00:00
Devin AI
6a59194e6f fix: improve type validation logic to fix type checker error
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 17:03:22 +00:00
Devin AI
5f3eb3605a fix: remove duplicate return statement in copy method
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 17:02:06 +00:00
Devin AI
77990b6293 fix: resolve type checking errors and improve copy method
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 17:01:10 +00:00
Devin AI
050ead62a7 fix: resolve remaining type checking errors
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 17:00:35 +00:00
Devin AI
8f3936eb09 fix: resolve type checking errors
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 16:58:01 +00:00
Devin AI
9c1f24ee26 fix: improve type validation logic
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 16:57:10 +00:00
Devin AI
24da7be540 fix: sort imports in test file
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 16:57:07 +00:00
Devin AI
31e8b9d7f2 refactor: implement code review suggestions
- Use typing.get_type_hints for better type checking
- Add proper handling of dict return types
- Improve parameter validation for keyword-only params
- Add comprehensive test coverage

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 16:54:27 +00:00
Devin AI
0e086d348a refactor: implement code review suggestions
- Add GuardrailValidationError exception
- Use typing.get_type_hints for better type checking
- Add descriptive error messages with context
- Support dict return type

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 16:50:33 +00:00
Devin AI
3c5672f4ec feat: improve guardrail validation support
- Add support for new style tuple annotations
- Allow specific return types like tuple[bool, str]
- Support optional parameters in guardrail functions

Fixes #2177

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-20 16:43:02 +00:00
4 changed files with 292 additions and 118 deletions

View File

@@ -3,17 +3,20 @@ import inspect
import json
import logging
import threading
import typing
import uuid
from concurrent.futures import Future
from copy import copy
from hashlib import md5
from pathlib import Path
from typing import (
AbstractSet,
Any,
Callable,
ClassVar,
Dict,
List,
Mapping,
Optional,
Set,
Tuple,
@@ -32,6 +35,7 @@ from pydantic import (
from pydantic_core import PydanticCustomError
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.tasks.exceptions import GuardrailValidationError
from crewai.tasks.guardrail_result import GuardrailResult
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
@@ -113,7 +117,7 @@ class Task(BaseModel):
description="Task output, it's final result after being executed", default=None
)
tools: Optional[List[BaseTool]] = Field(
default_factory=list,
default_factory=list[BaseTool],
description="Tools the agent is limited to use for this task.",
)
id: UUID4 = Field(
@@ -129,7 +133,7 @@ class Task(BaseModel):
description="A converter class used to export structured output",
default=None,
)
processed_by_agents: Set[str] = Field(default_factory=set)
processed_by_agents: Set[str] = Field(default_factory=set[str])
guardrail: Optional[Callable[[TaskOutput], Tuple[bool, Any]]] = Field(
default=None,
description="Function to validate task output before proceeding to next task",
@@ -151,8 +155,8 @@ class Task(BaseModel):
"""Validate that the guardrail function has the correct signature and behavior.
While type hints provide static checking, this validator ensures runtime safety by:
1. Verifying the function accepts exactly one parameter (the TaskOutput)
2. Checking return type annotations match Tuple[bool, Any] if present
1. Verifying the function accepts exactly one required positional parameter (the TaskOutput)
2. Checking return type annotations match tuple[bool, Any] or specific types like tuple[bool, str]
3. Providing clear, immediate error messages for debugging
This runtime validation is crucial because:
@@ -160,6 +164,24 @@ class Task(BaseModel):
- Function signatures need immediate validation before task execution
- Clear error messages help users debug guardrail implementation issues
Examples:
Simple validation with new style annotation:
>>> def validate_output(result: TaskOutput) -> tuple[bool, str]:
... return (True, result.raw.upper())
Validation with optional parameters:
>>> def validate_with_options(result: TaskOutput, strict: bool = True) -> tuple[bool, str]:
... if strict and not result.raw.isupper():
... return (False, "Text must be uppercase")
... return (True, result.raw)
Validation with specific return type:
>>> def validate_task_output(result: TaskOutput) -> tuple[bool, TaskOutput]:
... if not result.raw:
... return (False, result)
... result.raw = result.raw.strip()
... return (True, result)
Args:
v: The guardrail function to validate
@@ -168,22 +190,57 @@ class Task(BaseModel):
Raises:
ValueError: If the function signature is invalid or return annotation
doesn't match Tuple[bool, Any]
doesn't match tuple[bool, Any] or specific allowed types
"""
if v is not None:
sig = inspect.signature(v)
if len(sig.parameters) != 1:
raise ValueError("Guardrail function must accept exactly one parameter")
# Get required positional parameters (excluding those with defaults)
required_params = [
param for param in sig.parameters.values()
if param.default == inspect.Parameter.empty
and param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
]
keyword_only_params = [
param for param in sig.parameters.values()
if param.kind == inspect.Parameter.KEYWORD_ONLY
]
if len(required_params) != 1 or (len(keyword_only_params) > 0 and any(p.default == inspect.Parameter.empty for p in keyword_only_params)):
raise GuardrailValidationError(
"Guardrail function must accept exactly one required positional parameter and no required keyword-only parameters",
{"params": [str(p) for p in sig.parameters.values()]}
)
# Check return annotation if present, but don't require it
return_annotation = sig.return_annotation
if return_annotation != inspect.Signature.empty:
if not (
return_annotation == Tuple[bool, Any]
or str(return_annotation) == "Tuple[bool, Any]"
):
raise ValueError(
"If return type is annotated, it must be Tuple[bool, Any]"
type_hints = typing.get_type_hints(v)
return_annotation = type_hints.get('return')
if return_annotation:
# Convert annotation to string for comparison
annotation_str = str(return_annotation).lower().replace(' ', '')
# Normalize type strings
normalized_annotation = (
annotation_str.replace('typing.', '')
.replace('dict[str,typing.any]', 'dict[str,any]')
.replace('dict[str, any]', 'dict[str,any]')
)
VALID_RETURN_TYPES = {
'tuple[bool,any]',
'tuple[bool,str]',
'tuple[bool,dict[str,any]]',
'tuple[bool,taskoutput]'
}
# Check if the normalized annotation matches any valid pattern
is_valid = normalized_annotation == 'tuple[bool,any]'
if not is_valid:
is_valid = normalized_annotation in VALID_RETURN_TYPES
if not is_valid:
raise GuardrailValidationError(
f"Invalid return type annotation. Expected one of: "
f"{', '.join(VALID_RETURN_TYPES)}",
{"got": annotation_str}
)
return v
@@ -411,6 +468,7 @@ class Task(BaseModel):
"Task guardrail returned None as result. This is not allowed."
)
# Handle different result types
if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output(
@@ -420,6 +478,13 @@ class Task(BaseModel):
task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result
elif isinstance(guardrail_result.result, dict):
task_output.raw = guardrail_result.result
task_output.json_dict = guardrail_result.result
pydantic_output, _ = self._export_output(
json.dumps(guardrail_result.result)
)
task_output.pydantic = pydantic_output
self.output = task_output
self.end_time = datetime.datetime.now()
@@ -610,40 +675,74 @@ class Task(BaseModel):
self.delegations += 1
def copy(
self, agents: List["BaseAgent"], task_mapping: Dict[str, "Task"]
self,
agents: List["BaseAgent"] | None = None,
task_mapping: Dict[str, "Task"] | None = None,
*,
include: AbstractSet[int] | AbstractSet[str] | Mapping[int, Any] | Mapping[str, Any] | None = None,
exclude: AbstractSet[int] | AbstractSet[str] | Mapping[int, Any] | Mapping[str, Any] | None = None,
update: dict[str, Any] | None = None,
deep: bool = False,
) -> "Task":
"""Create a deep copy of the Task."""
exclude = {
"""Create a deep copy of the Task.
Args:
agents: Optional list of agents to copy agent references
task_mapping: Optional mapping of task keys to tasks for context
include: Fields to include in the copy
exclude: Fields to exclude from the copy
update: Fields to update in the copy
deep: Whether to perform a deep copy
"""
if agents is None and task_mapping is None:
# New style copy using BaseModel
copied = super().copy(
include=include,
exclude=exclude,
update=update,
deep=deep,
)
# Copy mutable fields
if self.tools:
copied.tools = copy(self.tools)
if self.context:
copied.context = copy(self.context)
return copied
# Legacy copy behavior
exclude_fields = {
"id",
"agent",
"context",
"tools",
}
copied_data = self.model_dump(exclude=exclude)
copied_data = self.model_dump(exclude=exclude_fields)
copied_data = {k: v for k, v in copied_data.items() if v is not None}
cloned_context = (
[task_mapping[context_task.key] for context_task in self.context]
if self.context
if self.context and task_mapping
else None
)
def get_agent_by_role(role: str) -> Union["BaseAgent", None]:
if not agents:
return None
return next((agent for agent in agents if agent.role == role), None)
cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None
cloned_tools = copy(self.tools) if self.tools else []
copied_task = Task(
return Task(
**copied_data,
context=cloned_context,
agent=cloned_agent,
tools=cloned_tools,
)
return copied_task
def _export_output(
self, result: str
) -> Tuple[Optional[BaseModel], Optional[Dict[str, Any]]]:

View File

@@ -0,0 +1,25 @@
"""
Module for task-related exceptions.
This module provides custom exceptions used throughout the task system
to provide more specific error handling and context.
"""
from typing import Any, Dict, Optional
class GuardrailValidationError(Exception):
    """Exception raised for guardrail validation errors.

    This exception provides detailed context about why a guardrail
    validation failed, including the specific validation that failed
    and any relevant context information.

    Attributes:
        message: A clear description of the validation error.
        context: Dictionary containing additional error context; empty
            when no context was supplied.
    """

    def __init__(self, message: str, context: Optional[Dict[str, Any]] = None):
        """Store the error message and a snapshot of the supplied context.

        Args:
            message: Human-readable description of the validation failure.
                It is also passed to ``Exception.__init__``, so
                ``str(error)`` yields this text.
            context: Optional mapping with extra details about the failure.
                ``None`` (or an empty mapping) results in an empty dict.
        """
        self.message = message
        # Take a shallow copy so that later mutation of the caller's dict
        # cannot silently change the context recorded on this error.
        self.context = dict(context) if context else {}
        super().__init__(self.message)

View File

@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional
from typing import Any, Dict, Optional, Union
from pydantic import BaseModel, Field, model_validator
@@ -15,7 +15,7 @@ class TaskOutput(BaseModel):
description="Expected output of the task", default=None
)
summary: Optional[str] = Field(description="Summary of the task", default=None)
raw: str = Field(description="Raw output of the task", default="")
raw: Any = Field(description="Raw output of the task", default="")
pydantic: Optional[BaseModel] = Field(
description="Pydantic output of task", default=None
)

View File

@@ -1,129 +1,179 @@
"""Tests for task guardrails functionality."""
from typing import Any, Dict
from unittest.mock import Mock
import pytest
from crewai.task import Task
from crewai.tasks.exceptions import GuardrailValidationError
from crewai.tasks.task_output import TaskOutput
def test_task_without_guardrail():
"""Test that tasks work normally without guardrails (backward compatibility)."""
agent = Mock()
agent.role = "test_agent"
agent.execute_task.return_value = "test result"
agent.crew = None
class TestTaskGuardrails:
"""Test suite for task guardrail functionality."""
task = Task(description="Test task", expected_output="Output")
@pytest.fixture
def mock_agent(self):
"""Fixture providing a mock agent for testing."""
agent = Mock()
agent.role = "test_agent"
agent.crew = None
return agent
result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput)
assert result.raw == "test result"
def test_task_without_guardrail(self, mock_agent):
"""Test that tasks work normally without guardrails (backward compatibility)."""
mock_agent.execute_task.return_value = "test result"
task = Task(description="Test task", expected_output="Output")
result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput)
assert result.raw == "test result"
def test_task_with_successful_guardrail():
"""Test that successful guardrail validation passes transformed result."""
def test_task_with_successful_guardrail(self, mock_agent):
"""Test that successful guardrail validation passes transformed result."""
def guardrail(result: TaskOutput):
return (True, result.raw.upper())
def guardrail(result: TaskOutput):
return (True, result.raw.upper())
mock_agent.execute_task.return_value = "test result"
task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
agent = Mock()
agent.role = "test_agent"
agent.execute_task.return_value = "test result"
agent.crew = None
task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput)
assert result.raw == "TEST RESULT"
result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput)
assert result.raw == "TEST RESULT"
def test_task_with_failing_guardrail():
"""Test that failing guardrail triggers retry with error context."""
def test_task_with_failing_guardrail(self, mock_agent):
"""Test that failing guardrail triggers retry with error context."""
def guardrail(result: TaskOutput):
return (False, "Invalid format")
def guardrail(result: TaskOutput):
return (False, "Invalid format")
mock_agent.execute_task.side_effect = ["bad result", "good result"]
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=1,
)
agent = Mock()
agent.role = "test_agent"
agent.execute_task.side_effect = ["bad result", "good result"]
agent.crew = None
# First execution fails guardrail, second succeeds
mock_agent.execute_task.side_effect = ["bad result", "good result"]
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=mock_agent)
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=1,
)
# First execution fails guardrail, second succeeds
agent.execute_task.side_effect = ["bad result", "good result"]
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent)
assert "Task failed guardrail validation" in str(exc_info.value)
assert task.retry_count == 1
assert "Task failed guardrail validation" in str(exc_info.value)
assert task.retry_count == 1
def test_task_with_guardrail_retries():
"""Test that guardrail respects max_retries configuration."""
def test_task_with_guardrail_retries(self, mock_agent):
"""Test that guardrail respects max_retries configuration."""
def guardrail(result: TaskOutput):
return (False, "Invalid format")
def guardrail(result: TaskOutput):
return (False, "Invalid format")
mock_agent.execute_task.return_value = "bad result"
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=2,
)
agent = Mock()
agent.role = "test_agent"
agent.execute_task.return_value = "bad result"
agent.crew = None
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=mock_agent)
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=2,
)
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent)
assert task.retry_count == 2
assert "Task failed guardrail validation after 2 retries" in str(exc_info.value)
assert "Invalid format" in str(exc_info.value)
assert task.retry_count == 2
assert "Task failed guardrail validation after 2 retries" in str(exc_info.value)
assert "Invalid format" in str(exc_info.value)
def test_guardrail_error_in_context():
"""Test that guardrail error is passed in context for retry."""
def test_guardrail_error_in_context(self, mock_agent):
"""Test that guardrail error is passed in context for retry."""
def guardrail(result: TaskOutput):
return (False, "Expected JSON, got string")
def guardrail(result: TaskOutput):
return (False, "Expected JSON, got string")
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=1,
)
agent = Mock()
agent.role = "test_agent"
agent.crew = None
# Mock execute_task to succeed on second attempt
first_call = True
def execute_task(task, context, tools):
nonlocal first_call
if first_call:
first_call = False
return "invalid"
return '{"valid": "json"}'
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=1,
)
mock_agent.execute_task.side_effect = execute_task
# Mock execute_task to succeed on second attempt
first_call = True
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=mock_agent)
def execute_task(task, context, tools):
nonlocal first_call
if first_call:
first_call = False
return "invalid"
return '{"valid": "json"}'
assert "Task failed guardrail validation" in str(exc_info.value)
assert "Expected JSON, got string" in str(exc_info.value)
agent.execute_task.side_effect = execute_task
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent)
def test_guardrail_with_new_style_annotation(self, mock_agent):
"""Test guardrail with new style tuple annotation."""
def guardrail(result: TaskOutput) -> tuple[bool, str]:
return (True, result.raw.upper())
mock_agent.execute_task.return_value = "test result"
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail
)
assert "Task failed guardrail validation" in str(exc_info.value)
assert "Expected JSON, got string" in str(exc_info.value)
result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput)
assert result.raw == "TEST RESULT"
def test_guardrail_with_optional_params(self, mock_agent):
"""Test guardrail with optional parameters."""
def guardrail(result: TaskOutput, optional_param: str = "default") -> tuple[bool, str]:
return (True, f"{result.raw}-{optional_param}")
mock_agent.execute_task.return_value = "test"
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail
)
result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput)
assert result.raw == "test-default"
def test_guardrail_with_invalid_optional_params(self, mock_agent):
"""Test guardrail with invalid optional parameters."""
def guardrail(result: TaskOutput, *, required_kwonly: str) -> tuple[bool, str]:
return (True, result.raw)
with pytest.raises(GuardrailValidationError) as exc_info:
Task(
description="Test task",
expected_output="Output",
guardrail=guardrail
)
assert "exactly one required positional parameter" in str(exc_info.value)
def test_guardrail_with_dict_return_type(self, mock_agent):
"""Test guardrail with dict return type."""
def guardrail(result: TaskOutput) -> tuple[bool, dict[str, Any]]:
return (True, {"processed": result.raw.upper()})
mock_agent.execute_task.return_value = "test"
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail
)
result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput)
assert result.raw == {"processed": "TEST"}