feat: enhance task guardrail functionality and validation

* feat: enhance task guardrail functionality and validation

- Introduced support for multiple guardrails in the Task class, allowing for sequential processing of guardrails.
- Added a new `guardrails` field to the Task model to accept a list of callable guardrails or string descriptions.
- Implemented validation to ensure guardrails are processed correctly, including handling of retries and error messages.
- Enhanced the `_invoke_guardrail_function` method to manage guardrail execution and integrate with existing task output processing.
- Updated tests to cover various scenarios involving multiple guardrails, including success, failure, and retry mechanisms.

This update improves the flexibility and robustness of task execution by allowing for more complex validation scenarios.

* refactor: enhance guardrail type handling in Task model

- Updated the Task class to improve guardrail type definitions, introducing GuardrailType and GuardrailsType for better clarity and type safety.
- Simplified the validation logic for guardrails, ensuring that both single and multiple guardrails are processed correctly.
- Enhanced error messages for guardrail validation to provide clearer feedback when incorrect types are provided.
- This refactor improves the maintainability and robustness of task execution by standardizing guardrail handling.

* feat: implement per-guardrail retry tracking in Task model

- Introduced a new private attribute `_guardrail_retry_counts` to the Task class for tracking retry attempts on a per-guardrail basis.
- Updated the guardrail processing logic to utilize the new retry tracking, allowing for independent retry counts for each guardrail.
- Enhanced error handling to provide clearer feedback when guardrails fail validation after exceeding retry limits.
- Modified existing tests to validate the new retry tracking behavior, ensuring accurate assertions on guardrail retries.

This update improves the robustness and flexibility of task execution by allowing for more granular control over guardrail validation and retry mechanisms.
This commit is contained in:
Lorenze Jay
2025-10-18 15:46:30 -07:00
committed by GitHub
parent d7ac90c9e2
commit 09a390d50f
4 changed files with 683 additions and 136 deletions

View File

@@ -1,7 +1,7 @@
from collections.abc import Callable, Sequence
import shutil
import subprocess
import time
from collections.abc import Callable, Sequence
from typing import (
Any,
Literal,
@@ -876,6 +876,7 @@ class Agent(BaseAgent):
i18n=self.i18n,
original_agent=self,
guardrail=self.guardrail,
guardrail_max_retries=self.guardrail_max_retries,
)
return await lite_agent.kickoff_async(messages)

View File

@@ -1,15 +1,13 @@
import datetime
import inspect
import json
import logging
import threading
import uuid
import warnings
from collections.abc import Callable
from concurrent.futures import Future
from copy import copy as shallow_copy
import datetime
from hashlib import md5
import inspect
import json
import logging
from pathlib import Path
import threading
from typing import (
Any,
ClassVar,
@@ -17,6 +15,8 @@ from typing import (
get_args,
get_origin,
)
import uuid
import warnings
from pydantic import (
UUID4,
@@ -42,11 +42,16 @@ from crewai.tools.base_tool import BaseTool
from crewai.utilities.config import process_config
from crewai.utilities.constants import NOT_SPECIFIED, _NotSpecified
from crewai.utilities.converter import Converter, convert_to_model
from crewai.utilities.guardrail import process_guardrail
from crewai.utilities.guardrail import (
GuardrailType,
GuardrailsType,
process_guardrail,
)
from crewai.utilities.i18n import I18N
from crewai.utilities.printer import Printer
from crewai.utilities.string_utils import interpolate_only
_printer = Printer()
@@ -150,10 +155,15 @@ class Task(BaseModel):
default=None,
)
processed_by_agents: set[str] = Field(default_factory=set)
guardrail: Callable[[TaskOutput], tuple[bool, Any]] | str | None = Field(
guardrail: GuardrailType = Field(
default=None,
description="Function or string description of a guardrail to validate task output before proceeding to next task",
)
guardrails: GuardrailsType = Field(
default=None,
description="List of guardrails to validate task output before proceeding to next task. Also supports a single guardrail function or string description of a guardrail to validate task output before proceeding to next task",
)
max_retries: int | None = Field(
default=None,
description="[DEPRECATED] Maximum number of retries when guardrail fails. Use guardrail_max_retries instead. Will be removed in v1.0.0",
@@ -234,6 +244,12 @@ class Task(BaseModel):
return v
_guardrail: Callable | None = PrivateAttr(default=None)
_guardrails: list[Callable[[TaskOutput], tuple[bool, Any]] | str] = PrivateAttr(
default=[]
)
_guardrail_retry_counts: dict[int, int] = PrivateAttr(
default_factory=dict,
)
_original_description: str | None = PrivateAttr(default=None)
_original_expected_output: str | None = PrivateAttr(default=None)
_original_output_file: str | None = PrivateAttr(default=None)
@@ -270,6 +286,50 @@ class Task(BaseModel):
return self
@model_validator(mode="after")
def ensure_guardrails_is_list_of_callables(self) -> "Task":
guardrails = []
if self.guardrails is not None:
if isinstance(self.guardrails, (list, tuple)):
if len(self.guardrails) > 0:
for guardrail in self.guardrails:
if callable(guardrail):
guardrails.append(guardrail)
elif isinstance(guardrail, str):
if self.agent is None:
raise ValueError(
"Agent is required to use non-programmatic guardrails"
)
from crewai.tasks.llm_guardrail import LLMGuardrail
guardrails.append(
LLMGuardrail(description=guardrail, llm=self.agent.llm)
)
else:
raise ValueError("Guardrail must be a callable or a string")
else:
if callable(self.guardrails):
guardrails.append(self.guardrails)
elif isinstance(self.guardrails, str):
if self.agent is None:
raise ValueError(
"Agent is required to use non-programmatic guardrails"
)
from crewai.tasks.llm_guardrail import LLMGuardrail
guardrails.append(
LLMGuardrail(description=self.guardrails, llm=self.agent.llm)
)
else:
raise ValueError("Guardrail must be a callable or a string")
self._guardrails = guardrails
if self._guardrails:
self.guardrail = None
self._guardrail = None
return self
@field_validator("id", mode="before")
@classmethod
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
@@ -458,48 +518,24 @@ class Task(BaseModel):
output_format=self._get_output_format(),
)
if self._guardrails:
for idx, guardrail in enumerate(self._guardrails):
task_output = self._invoke_guardrail_function(
task_output=task_output,
agent=agent,
tools=tools,
guardrail=guardrail,
guardrail_index=idx,
)
# backwards support
if self._guardrail:
guardrail_result = process_guardrail(
output=task_output,
task_output = self._invoke_guardrail_function(
task_output=task_output,
agent=agent,
tools=tools,
guardrail=self._guardrail,
retry_count=self.retry_count,
event_source=self,
from_task=self,
from_agent=agent,
)
if not guardrail_result.success:
if self.retry_count >= self.guardrail_max_retries:
raise Exception(
f"Task failed guardrail validation after {self.guardrail_max_retries} retries. "
f"Last error: {guardrail_result.error}"
)
self.retry_count += 1
context = self.i18n.errors("validation_error").format(
guardrail_result_error=guardrail_result.error,
task_output=task_output.raw,
)
printer = Printer()
printer.print(
content=f"Guardrail blocked, retrying, due to: {guardrail_result.error}\n",
color="yellow",
)
return self._execute_core(agent, context, tools)
if guardrail_result.result is None:
raise Exception(
"Task guardrail returned None as result. This is not allowed."
)
if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output(
guardrail_result.result
)
task_output.pydantic = pydantic_output
task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result
self.output = task_output
self.end_time = datetime.datetime.now()
@@ -628,7 +664,10 @@ Follow these guidelines:
try:
crew_chat_messages = json.loads(crew_chat_messages_json)
except json.JSONDecodeError as e:
_printer.print(f"An error occurred while parsing crew chat messages: {e}", color="red")
_printer.print(
f"An error occurred while parsing crew chat messages: {e}",
color="red",
)
raise
conversation_history = "\n".join(
@@ -791,3 +830,101 @@ Follow these guidelines:
Fingerprint: The fingerprint of the task
"""
return self.security_config.fingerprint
def _invoke_guardrail_function(
self,
task_output: TaskOutput,
agent: BaseAgent,
tools: list[BaseTool],
guardrail: Callable | None,
guardrail_index: int | None = None,
) -> TaskOutput:
if not guardrail:
return task_output
if guardrail_index is not None:
current_retry_count = self._guardrail_retry_counts.get(guardrail_index, 0)
else:
current_retry_count = self.retry_count
max_attempts = self.guardrail_max_retries + 1
for attempt in range(max_attempts):
guardrail_result = process_guardrail(
output=task_output,
guardrail=guardrail,
retry_count=current_retry_count,
event_source=self,
from_task=self,
from_agent=agent,
)
if guardrail_result.success:
# Guardrail passed
if guardrail_result.result is None:
raise Exception(
"Task guardrail returned None as result. This is not allowed."
)
if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output(
guardrail_result.result
)
task_output.pydantic = pydantic_output
task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result
return task_output
# Guardrail failed
if attempt >= self.guardrail_max_retries:
# Max retries reached
guardrail_name = (
f"guardrail {guardrail_index}"
if guardrail_index is not None
else "guardrail"
)
raise Exception(
f"Task failed {guardrail_name} validation after {self.guardrail_max_retries} retries. "
f"Last error: {guardrail_result.error}"
)
if guardrail_index is not None:
current_retry_count += 1
self._guardrail_retry_counts[guardrail_index] = current_retry_count
else:
self.retry_count += 1
current_retry_count = self.retry_count
context = self.i18n.errors("validation_error").format(
guardrail_result_error=guardrail_result.error,
task_output=task_output.raw,
)
printer = Printer()
printer.print(
content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n",
color="yellow",
)
# Regenerate output from agent
result = agent.execute_task(
task=self,
context=context,
tools=tools,
)
pydantic_output, json_output = self._export_output(result)
task_output = TaskOutput(
name=self.name or self.description,
description=self.description,
expected_output=self.expected_output,
raw=result,
pydantic=pydantic_output,
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
)
return task_output

View File

@@ -1,17 +1,22 @@
from __future__ import annotations
from collections.abc import Callable
from typing import TYPE_CHECKING, Any
from collections.abc import Callable, Sequence
from typing import TYPE_CHECKING, Any, TypeAlias
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Self
if TYPE_CHECKING:
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.lite_agent import LiteAgent, LiteAgentOutput
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
GuardrailType: TypeAlias = Callable[["TaskOutput"], tuple[bool, Any]] | str | None
GuardrailsType: TypeAlias = Sequence[GuardrailType] | GuardrailType
class GuardrailResult(BaseModel):
"""Result from a task guardrail execution.

View File

@@ -1,7 +1,7 @@
import threading
from unittest.mock import Mock, patch
import pytest
from crewai import Agent, Task
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_types import (
@@ -14,6 +14,24 @@ from crewai.tasks.llm_guardrail import LLMGuardrail
from crewai.tasks.task_output import TaskOutput
def create_smart_task(**kwargs):
"""
Smart task factory that automatically assigns a mock agent when guardrails are present.
This maintains backward compatibility while handling the agent requirement for guardrails.
"""
guardrails_list = kwargs.get("guardrails")
has_guardrails = kwargs.get("guardrail") is not None or (
guardrails_list is not None and len(guardrails_list) > 0
)
if has_guardrails and kwargs.get("agent") is None:
kwargs["agent"] = Agent(
role="test_agent", goal="test_goal", backstory="test_backstory"
)
return Task(**kwargs)
def test_task_without_guardrail():
"""Test that tasks work normally without guardrails (backward compatibility)."""
agent = Mock()
@@ -21,7 +39,7 @@ def test_task_without_guardrail():
agent.execute_task.return_value = "test result"
agent.crew = None
task = Task(description="Test task", expected_output="Output")
task = create_smart_task(description="Test task", expected_output="Output")
result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput)
@@ -39,7 +57,9 @@ def test_task_with_successful_guardrail_func():
agent.execute_task.return_value = "test result"
agent.crew = None
task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
task = create_smart_task(
description="Test task", expected_output="Output", guardrail=guardrail
)
result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput)
@@ -57,7 +77,7 @@ def test_task_with_failing_guardrail():
agent.execute_task.side_effect = ["bad result", "good result"]
agent.crew = None
task = Task(
task = create_smart_task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
@@ -84,7 +104,7 @@ def test_task_with_guardrail_retries():
agent.execute_task.return_value = "bad result"
agent.crew = None
task = Task(
task = create_smart_task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
@@ -109,7 +129,7 @@ def test_guardrail_error_in_context():
agent.role = "test_agent"
agent.crew = None
task = Task(
task = create_smart_task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
@@ -176,92 +196,78 @@ def test_task_guardrail_process_output(task_output):
def test_guardrail_emits_events(sample_agent):
started_guardrail = []
completed_guardrail = []
all_events_received = threading.Event()
expected_started = 3 # 2 from first task, 1 from second
expected_completed = 3 # 2 from first task, 1 from second
task1 = Task(
task = create_smart_task(
description="Gather information about available books on the First World War",
agent=sample_agent,
expected_output="A list of available books on the First World War",
guardrail="Ensure the authors are from Italy",
)
@crewai_event_bus.on(LLMGuardrailStartedEvent)
def handle_guardrail_started(source, event):
started_guardrail.append(
{"guardrail": event.guardrail, "retry_count": event.retry_count}
)
if (
len(started_guardrail) >= expected_started
and len(completed_guardrail) >= expected_completed
):
all_events_received.set()
with crewai_event_bus.scoped_handlers():
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def handle_guardrail_completed(source, event):
completed_guardrail.append(
@crewai_event_bus.on(LLMGuardrailStartedEvent)
def handle_guardrail_started(source, event):
assert source == task
started_guardrail.append(
{"guardrail": event.guardrail, "retry_count": event.retry_count}
)
@crewai_event_bus.on(LLMGuardrailCompletedEvent)
def handle_guardrail_completed(source, event):
assert source == task
completed_guardrail.append(
{
"success": event.success,
"result": event.result,
"error": event.error,
"retry_count": event.retry_count,
}
)
result = task.execute_sync(agent=sample_agent)
def custom_guardrail(result: TaskOutput):
return (True, "good result from callable function")
task = create_smart_task(
description="Test task",
expected_output="Output",
guardrail=custom_guardrail,
)
task.execute_sync(agent=sample_agent)
expected_started_events = [
{"guardrail": "Ensure the authors are from Italy", "retry_count": 0},
{"guardrail": "Ensure the authors are from Italy", "retry_count": 1},
{
"success": event.success,
"result": event.result,
"error": event.error,
"retry_count": event.retry_count,
}
)
if (
len(started_guardrail) >= expected_started
and len(completed_guardrail) >= expected_completed
):
all_events_received.set()
"guardrail": """def custom_guardrail(result: TaskOutput):
return (True, "good result from callable function")""",
"retry_count": 0,
},
]
result = task1.execute_sync(agent=sample_agent)
def custom_guardrail(result: TaskOutput):
return (True, "good result from callable function")
task2 = Task(
description="Test task",
expected_output="Output",
guardrail=custom_guardrail,
)
task2.execute_sync(agent=sample_agent)
# Wait for all events to be received
assert all_events_received.wait(timeout=10), (
"Timeout waiting for all guardrail events"
)
expected_started_events = [
{"guardrail": "Ensure the authors are from Italy", "retry_count": 0},
{"guardrail": "Ensure the authors are from Italy", "retry_count": 1},
{
"guardrail": """def custom_guardrail(result: TaskOutput):
return (True, "good result from callable function")""",
"retry_count": 0,
},
]
expected_completed_events = [
{
"success": False,
"result": None,
"error": "The task result does not comply with the guardrail because none of "
"the listed authors are from Italy. All authors mentioned are from "
"different countries, including Germany, the UK, the USA, and others, "
"which violates the requirement that authors must be Italian.",
"retry_count": 0,
},
{"success": True, "result": result.raw, "error": None, "retry_count": 1},
{
"success": True,
"result": "good result from callable function",
"error": None,
"retry_count": 0,
},
]
assert started_guardrail == expected_started_events
assert completed_guardrail == expected_completed_events
expected_completed_events = [
{
"success": False,
"result": None,
"error": "The task result does not comply with the guardrail because none of "
"the listed authors are from Italy. All authors mentioned are from "
"different countries, including Germany, the UK, the USA, and others, "
"which violates the requirement that authors must be Italian.",
"retry_count": 0,
},
{"success": True, "result": result.raw, "error": None, "retry_count": 1},
{
"success": True,
"result": "good result from callable function",
"error": None,
"retry_count": 0,
},
]
assert started_guardrail == expected_started_events
assert completed_guardrail == expected_completed_events
@pytest.mark.vcr(filter_headers=["authorization"])
@@ -276,7 +282,7 @@ def test_guardrail_when_an_error_occurs(sample_agent, task_output):
match="Error while validating the task output: Unexpected error",
),
):
task = Task(
task = create_smart_task(
description="Gather information about available books on the First World War",
agent=sample_agent,
expected_output="A list of available books on the First World War",
@@ -298,7 +304,7 @@ def test_hallucination_guardrail_integration():
context="Test reference context for validation", llm=mock_llm, threshold=8.0
)
task = Task(
task = create_smart_task(
description="Test task with hallucination guardrail",
expected_output="Valid output",
guardrail=guardrail,
@@ -318,3 +324,401 @@ def test_hallucination_guardrail_description_in_events():
event = LLMGuardrailStartedEvent(guardrail=guardrail, retry_count=0)
assert event.guardrail == "HallucinationGuardrail (no-op)"
def test_multiple_guardrails_sequential_processing():
"""Test that multiple guardrails are processed sequentially."""
def first_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""First guardrail adds prefix."""
return (True, f"[FIRST] {result.raw}")
def second_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Second guardrail adds suffix."""
return (True, f"{result.raw} [SECOND]")
def third_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Third guardrail converts to uppercase."""
return (True, result.raw.upper())
agent = Mock()
agent.role = "sequential_agent"
agent.execute_task.return_value = "original text"
agent.crew = None
task = create_smart_task(
description="Test sequential guardrails",
expected_output="Processed text",
guardrails=[first_guardrail, second_guardrail, third_guardrail],
)
result = task.execute_sync(agent=agent)
assert result.raw == "[FIRST] ORIGINAL TEXT [SECOND]"
def test_multiple_guardrails_with_validation_failure():
"""Test multiple guardrails where one fails validation."""
def length_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Ensure minimum length."""
if len(result.raw) < 10:
return (False, "Text too short")
return (True, result.raw)
def format_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Add formatting only if not already formatted."""
if not result.raw.startswith("Formatted:"):
return (True, f"Formatted: {result.raw}")
return (True, result.raw)
def validation_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Final validation."""
if "Formatted:" not in result.raw:
return (False, "Missing formatting")
return (True, result.raw)
# Use a callable that tracks calls and returns appropriate values
call_count = 0
def mock_execute_task(*args, **kwargs):
nonlocal call_count
call_count += 1
result = (
"short"
if call_count == 1
else "this is a longer text that meets requirements"
)
return result
agent = Mock()
agent.role = "validation_agent"
agent.execute_task = mock_execute_task
agent.crew = None
task = create_smart_task(
description="Test guardrails with validation",
expected_output="Valid formatted text",
guardrails=[length_guardrail, format_guardrail, validation_guardrail],
guardrail_max_retries=2,
)
result = task.execute_sync(agent=agent)
# The second call should be processed through all guardrails
assert result.raw == "Formatted: this is a longer text that meets requirements"
assert task._guardrail_retry_counts.get(0, 0) == 1
def test_multiple_guardrails_with_mixed_string_and_taskoutput():
"""Test guardrails that return both strings and TaskOutput objects."""
def string_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Returns a string."""
return (True, f"String: {result.raw}")
def taskoutput_guardrail(result: TaskOutput) -> tuple[bool, TaskOutput]:
"""Returns a TaskOutput object."""
new_output = TaskOutput(
name=result.name,
description=result.description,
expected_output=result.expected_output,
raw=f"TaskOutput: {result.raw}",
agent=result.agent,
output_format=result.output_format,
)
return (True, new_output)
def final_string_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Final string transformation."""
return (True, f"Final: {result.raw}")
agent = Mock()
agent.role = "mixed_agent"
agent.execute_task.return_value = "original"
agent.crew = None
task = create_smart_task(
description="Test mixed return types",
expected_output="Mixed processing",
guardrails=[string_guardrail, taskoutput_guardrail, final_string_guardrail],
)
result = task.execute_sync(agent=agent)
assert result.raw == "Final: TaskOutput: String: original"
def test_multiple_guardrails_with_retry_on_middle_guardrail():
"""Test that retry works correctly when a middle guardrail fails."""
call_count = {"first": 0, "second": 0, "third": 0}
def first_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Always succeeds."""
call_count["first"] += 1
return (True, f"First({call_count['first']}): {result.raw}")
def second_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Fails on first attempt, succeeds on second."""
call_count["second"] += 1
if call_count["second"] == 1:
return (False, "Second guardrail failed on first attempt")
return (True, f"Second({call_count['second']}): {result.raw}")
def third_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Always succeeds."""
call_count["third"] += 1
return (True, f"Third({call_count['third']}): {result.raw}")
agent = Mock()
agent.role = "retry_agent"
agent.execute_task.return_value = "base"
agent.crew = None
task = create_smart_task(
description="Test retry in middle guardrail",
expected_output="Retry handling",
guardrails=[first_guardrail, second_guardrail, third_guardrail],
guardrail_max_retries=2,
)
result = task.execute_sync(agent=agent)
assert task._guardrail_retry_counts.get(1, 0) == 1
assert call_count["first"] == 1
assert call_count["second"] == 2
assert call_count["third"] == 1
assert "Second(2)" in result.raw
def test_multiple_guardrails_with_max_retries_exceeded():
"""Test that exception is raised when max retries exceeded with multiple guardrails."""
def passing_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Always passes."""
return (True, f"Passed: {result.raw}")
def failing_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Always fails."""
return (False, "This guardrail always fails")
agent = Mock()
agent.role = "failing_agent"
agent.execute_task.return_value = "test"
agent.crew = None
task = create_smart_task(
description="Test max retries with multiple guardrails",
expected_output="Will fail",
guardrails=[passing_guardrail, failing_guardrail],
guardrail_max_retries=1,
)
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent)
assert "Task failed guardrail 1 validation after 1 retries" in str(exc_info.value)
assert "This guardrail always fails" in str(exc_info.value)
assert task._guardrail_retry_counts.get(1, 0) == 1
def test_multiple_guardrails_empty_list():
"""Test that empty guardrails list works correctly."""
agent = Mock()
agent.role = "empty_agent"
agent.execute_task.return_value = "no guardrails"
agent.crew = None
task = create_smart_task(
description="Test empty guardrails list",
expected_output="No processing",
guardrails=[],
)
result = task.execute_sync(agent=agent)
assert result.raw == "no guardrails"
def test_multiple_guardrails_with_llm_guardrails():
"""Test mixing callable and LLM guardrails."""
def callable_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Callable guardrail."""
return (True, f"Callable: {result.raw}")
# Create a proper mock agent without config issues
from crewai import Agent
agent = Agent(
role="mixed_guardrail_agent", goal="Test goal", backstory="Test backstory"
)
task = create_smart_task(
description="Test mixed guardrail types",
expected_output="Mixed processing",
guardrails=[callable_guardrail, "Ensure the output is professional"],
agent=agent,
)
# The LLM guardrail will be converted to LLMGuardrail internally
assert len(task._guardrails) == 2
assert callable(task._guardrails[0])
assert callable(task._guardrails[1]) # LLMGuardrail is callable
def test_multiple_guardrails_processing_order():
"""Test that guardrails are processed in the correct order."""
processing_order = []
def first_guardrail(result: TaskOutput) -> tuple[bool, str]:
processing_order.append("first")
return (True, f"1-{result.raw}")
def second_guardrail(result: TaskOutput) -> tuple[bool, str]:
processing_order.append("second")
return (True, f"2-{result.raw}")
def third_guardrail(result: TaskOutput) -> tuple[bool, str]:
processing_order.append("third")
return (True, f"3-{result.raw}")
agent = Mock()
agent.role = "order_agent"
agent.execute_task.return_value = "base"
agent.crew = None
task = create_smart_task(
description="Test processing order",
expected_output="Ordered processing",
guardrails=[first_guardrail, second_guardrail, third_guardrail],
)
result = task.execute_sync(agent=agent)
assert processing_order == ["first", "second", "third"]
assert result.raw == "3-2-1-base"
def test_multiple_guardrails_with_pydantic_output():
"""Test multiple guardrails with Pydantic output model."""
from pydantic import BaseModel, Field
class TestModel(BaseModel):
content: str = Field(description="The content")
processed: bool = Field(description="Whether it was processed")
def json_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Convert to JSON format."""
import json
data = {"content": result.raw, "processed": True}
return (True, json.dumps(data))
def validation_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Validate JSON structure."""
import json
try:
data = json.loads(result.raw)
if "content" not in data or "processed" not in data:
return (False, "Missing required fields")
return (True, result.raw)
except json.JSONDecodeError:
return (False, "Invalid JSON format")
agent = Mock()
agent.role = "pydantic_agent"
agent.execute_task.return_value = "test content"
agent.crew = None
task = create_smart_task(
description="Test guardrails with Pydantic",
expected_output="Structured output",
guardrails=[json_guardrail, validation_guardrail],
output_pydantic=TestModel,
)
result = task.execute_sync(agent=agent)
# Verify the result is valid JSON and can be parsed
import json
parsed = json.loads(result.raw)
assert parsed["content"] == "test content"
assert parsed["processed"] is True
def test_guardrails_vs_single_guardrail_mutual_exclusion():
"""Test that guardrails list nullifies single guardrail."""
def single_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Single guardrail - should be ignored."""
return (True, f"Single: {result.raw}")
def list_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""List guardrail - should be used."""
return (True, f"List: {result.raw}")
agent = Mock()
agent.role = "exclusion_agent"
agent.execute_task.return_value = "test"
agent.crew = None
task = create_smart_task(
description="Test mutual exclusion",
expected_output="Exclusion test",
guardrail=single_guardrail, # This should be ignored
guardrails=[list_guardrail], # This should be used
)
result = task.execute_sync(agent=agent)
# Should only use the guardrails list, not the single guardrail
assert result.raw == "List: test"
assert task._guardrail is None # Single guardrail should be nullified
def test_per_guardrail_independent_retry_tracking():
"""Test that each guardrail has independent retry tracking."""
call_counts = {"g1": 0, "g2": 0, "g3": 0}
def guardrail_1(result: TaskOutput) -> tuple[bool, str]:
"""Fails twice, then succeeds."""
call_counts["g1"] += 1
if call_counts["g1"] <= 2:
return (False, "Guardrail 1 not ready yet")
return (True, f"G1({call_counts['g1']}): {result.raw}")
def guardrail_2(result: TaskOutput) -> tuple[bool, str]:
"""Fails once, then succeeds."""
call_counts["g2"] += 1
if call_counts["g2"] == 1:
return (False, "Guardrail 2 not ready yet")
return (True, f"G2({call_counts['g2']}): {result.raw}")
def guardrail_3(result: TaskOutput) -> tuple[bool, str]:
"""Always succeeds."""
call_counts["g3"] += 1
return (True, f"G3({call_counts['g3']}): {result.raw}")
agent = Mock()
agent.role = "independent_retry_agent"
agent.execute_task.return_value = "base"
agent.crew = None
task = create_smart_task(
description="Test independent retry tracking",
expected_output="Independent retries",
guardrails=[guardrail_1, guardrail_2, guardrail_3],
guardrail_max_retries=3,
)
result = task.execute_sync(agent=agent)
assert task._guardrail_retry_counts.get(0, 0) == 2
assert task._guardrail_retry_counts.get(1, 0) == 1
assert task._guardrail_retry_counts.get(2, 0) == 0
assert call_counts["g1"] == 3
assert call_counts["g2"] == 2
assert call_counts["g3"] == 1
assert "G3(1)" in result.raw