diff --git a/lib/crewai/src/crewai/agent.py b/lib/crewai/src/crewai/agent.py
index dc96e1bf3..d6e1f50d4 100644
--- a/lib/crewai/src/crewai/agent.py
+++ b/lib/crewai/src/crewai/agent.py
@@ -1,7 +1,7 @@
+from collections.abc import Callable, Sequence
 import shutil
 import subprocess
 import time
-from collections.abc import Callable, Sequence
 from typing import (
     Any,
     Literal,
@@ -876,6 +876,7 @@ class Agent(BaseAgent):
             i18n=self.i18n,
             original_agent=self,
             guardrail=self.guardrail,
+            guardrail_max_retries=self.guardrail_max_retries,
         )
         return await lite_agent.kickoff_async(messages)
diff --git a/lib/crewai/src/crewai/task.py b/lib/crewai/src/crewai/task.py
index 1f757fd59..56dbc433f 100644
--- a/lib/crewai/src/crewai/task.py
+++ b/lib/crewai/src/crewai/task.py
@@ -1,15 +1,13 @@
-import datetime
-import inspect
-import json
-import logging
-import threading
-import uuid
-import warnings
 from collections.abc import Callable
 from concurrent.futures import Future
 from copy import copy as shallow_copy
+import datetime
 from hashlib import md5
+import inspect
+import json
+import logging
 from pathlib import Path
+import threading
 from typing import (
     Any,
     ClassVar,
@@ -17,6 +15,8 @@ from typing import (
     get_args,
     get_origin,
 )
+import uuid
+import warnings
 
 from pydantic import (
     UUID4,
@@ -42,11 +42,16 @@
 from crewai.tools.base_tool import BaseTool
 from crewai.utilities.config import process_config
 from crewai.utilities.constants import NOT_SPECIFIED, _NotSpecified
 from crewai.utilities.converter import Converter, convert_to_model
-from crewai.utilities.guardrail import process_guardrail
+from crewai.utilities.guardrail import (
+    GuardrailType,
+    GuardrailsType,
+    process_guardrail,
+)
 from crewai.utilities.i18n import I18N
 from crewai.utilities.printer import Printer
 from crewai.utilities.string_utils import interpolate_only
 
+
 _printer = Printer()
 
@@ -150,10 +155,15 @@ class Task(BaseModel):
         default=None,
     )
     processed_by_agents: set[str] = Field(default_factory=set)
-    guardrail: Callable[[TaskOutput], tuple[bool, Any]] | str | None = Field(
+    guardrail: GuardrailType = Field(
         default=None,
         description="Function or string description of a guardrail to validate task output before proceeding to next task",
     )
+    guardrails: GuardrailsType = Field(
+        default=None,
+        description="List of guardrails to validate task output before proceeding to the next task; also accepts a single guardrail callable or string description.",
+    )
+
     max_retries: int | None = Field(
         default=None,
         description="[DEPRECATED] Maximum number of retries when guardrail fails. Use guardrail_max_retries instead. Will be removed in v1.0.0",
@@ -234,6 +244,12 @@ class Task(BaseModel):
         return v
 
     _guardrail: Callable | None = PrivateAttr(default=None)
+    _guardrails: list[Callable[[TaskOutput], tuple[bool, Any]] | str] = PrivateAttr(
+        default_factory=list
+    )
+    _guardrail_retry_counts: dict[int, int] = PrivateAttr(
+        default_factory=dict,
+    )
     _original_description: str | None = PrivateAttr(default=None)
     _original_expected_output: str | None = PrivateAttr(default=None)
     _original_output_file: str | None = PrivateAttr(default=None)
@@ -270,6 +286,50 @@ class Task(BaseModel):
 
         return self
 
+    @model_validator(mode="after")
+    def ensure_guardrails_is_list_of_callables(self) -> "Task":
+        guardrails = []
+        if self.guardrails is not None:
+            if isinstance(self.guardrails, (list, tuple)):
+                if len(self.guardrails) > 0:
+                    for guardrail in self.guardrails:
+                        if callable(guardrail):
+                            guardrails.append(guardrail)
+                        elif isinstance(guardrail, str):
+                            if self.agent is None:
+                                raise ValueError(
+                                    "Agent is required to use non-programmatic guardrails"
+                                )
+                            from crewai.tasks.llm_guardrail import LLMGuardrail
+
+                            guardrails.append(
+                                LLMGuardrail(description=guardrail, llm=self.agent.llm)
+                            )
+                        else:
+                            raise ValueError("Guardrail must be a callable or a string")
+            else:
+                if callable(self.guardrails):
+                    guardrails.append(self.guardrails)
+                elif isinstance(self.guardrails, str):
+                    if self.agent is None:
+                        raise ValueError(
+                            "Agent is required to use non-programmatic guardrails"
+                        )
+                    from crewai.tasks.llm_guardrail import LLMGuardrail
+
+                    guardrails.append(
+                        LLMGuardrail(description=self.guardrails, llm=self.agent.llm)
+                    )
+                else:
+                    raise ValueError("Guardrail must be a callable or a string")
+
+        self._guardrails = guardrails
+        if self._guardrails:
+            self.guardrail = None
+            self._guardrail = None
+
+        return self
+
     @field_validator("id", mode="before")
     @classmethod
     def _deny_user_set_id(cls, v: UUID4 | None) -> None:
@@ -458,48 +518,24 @@ class Task(BaseModel):
             output_format=self._get_output_format(),
         )
 
+        if self._guardrails:
+            for idx, guardrail in enumerate(self._guardrails):
+                task_output = self._invoke_guardrail_function(
+                    task_output=task_output,
+                    agent=agent,
+                    tools=tools,
+                    guardrail=guardrail,
+                    guardrail_index=idx,
+                )
+
+        # backwards compatibility with the single-guardrail API
         if self._guardrail:
-            guardrail_result = process_guardrail(
-                output=task_output,
+            task_output = self._invoke_guardrail_function(
+                task_output=task_output,
+                agent=agent,
+                tools=tools,
                 guardrail=self._guardrail,
-                retry_count=self.retry_count,
-                event_source=self,
-                from_task=self,
-                from_agent=agent,
             )
-            if not guardrail_result.success:
-                if self.retry_count >= self.guardrail_max_retries:
-                    raise Exception(
-                        f"Task failed guardrail validation after {self.guardrail_max_retries} retries. "
-                        f"Last error: {guardrail_result.error}"
-                    )
-
-                self.retry_count += 1
-                context = self.i18n.errors("validation_error").format(
-                    guardrail_result_error=guardrail_result.error,
-                    task_output=task_output.raw,
-                )
-                printer = Printer()
-                printer.print(
-                    content=f"Guardrail blocked, retrying, due to: {guardrail_result.error}\n",
-                    color="yellow",
-                )
-                return self._execute_core(agent, context, tools)
-
-            if guardrail_result.result is None:
-                raise Exception(
-                    "Task guardrail returned None as result. This is not allowed."
-                )
-
-            if isinstance(guardrail_result.result, str):
-                task_output.raw = guardrail_result.result
-                pydantic_output, json_output = self._export_output(
-                    guardrail_result.result
-                )
-                task_output.pydantic = pydantic_output
-                task_output.json_dict = json_output
-            elif isinstance(guardrail_result.result, TaskOutput):
-                task_output = guardrail_result.result
 
         self.output = task_output
         self.end_time = datetime.datetime.now()
@@ -628,7 +664,10 @@ Follow these guidelines:
         try:
             crew_chat_messages = json.loads(crew_chat_messages_json)
         except json.JSONDecodeError as e:
-            _printer.print(f"An error occurred while parsing crew chat messages: {e}", color="red")
+            _printer.print(
+                f"An error occurred while parsing crew chat messages: {e}",
+                color="red",
+            )
             raise
 
         conversation_history = "\n".join(
@@ -791,3 +830,101 @@ Follow these guidelines:
             Fingerprint: The fingerprint of the task
         """
         return self.security_config.fingerprint
+
+    def _invoke_guardrail_function(
+        self,
+        task_output: TaskOutput,
+        agent: BaseAgent,
+        tools: list[BaseTool],
+        guardrail: Callable | None,
+        guardrail_index: int | None = None,
+    ) -> TaskOutput:
+        if not guardrail:
+            return task_output
+
+        if guardrail_index is not None:
+            current_retry_count = self._guardrail_retry_counts.get(guardrail_index, 0)
+        else:
+            current_retry_count = self.retry_count
+
+        max_attempts = self.guardrail_max_retries + 1
+
+        for attempt in range(max_attempts):
+            guardrail_result = process_guardrail(
+                output=task_output,
+                guardrail=guardrail,
+                retry_count=current_retry_count,
+                event_source=self,
+                from_task=self,
+                from_agent=agent,
+            )
+
+            if guardrail_result.success:
+                # Guardrail passed
+                if guardrail_result.result is None:
+                    raise Exception(
+                        "Task guardrail returned None as result. This is not allowed."
+                    )
+
+                if isinstance(guardrail_result.result, str):
+                    task_output.raw = guardrail_result.result
+                    pydantic_output, json_output = self._export_output(
+                        guardrail_result.result
+                    )
+                    task_output.pydantic = pydantic_output
+                    task_output.json_dict = json_output
+                elif isinstance(guardrail_result.result, TaskOutput):
+                    task_output = guardrail_result.result
+
+                return task_output
+
+            # Guardrail failed
+            if attempt >= self.guardrail_max_retries:
+                # Max retries reached
+                guardrail_name = (
+                    f"guardrail {guardrail_index}"
+                    if guardrail_index is not None
+                    else "guardrail"
+                )
+                raise Exception(
+                    f"Task failed {guardrail_name} validation after {self.guardrail_max_retries} retries. "
+                    f"Last error: {guardrail_result.error}"
+                )
+
+            if guardrail_index is not None:
+                current_retry_count += 1
+                self._guardrail_retry_counts[guardrail_index] = current_retry_count
+            else:
+                self.retry_count += 1
+                current_retry_count = self.retry_count
+
+            context = self.i18n.errors("validation_error").format(
+                guardrail_result_error=guardrail_result.error,
+                task_output=task_output.raw,
+            )
+            printer = Printer()
+            printer.print(
+                content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n",
+                color="yellow",
+            )
+
+            # Regenerate output from agent
+            result = agent.execute_task(
+                task=self,
+                context=context,
+                tools=tools,
+            )
+
+            pydantic_output, json_output = self._export_output(result)
+            task_output = TaskOutput(
+                name=self.name or self.description,
+                description=self.description,
+                expected_output=self.expected_output,
+                raw=result,
+                pydantic=pydantic_output,
+                json_dict=json_output,
+                agent=agent.role,
+                output_format=self._get_output_format(),
+            )
+
+        return task_output
diff --git a/lib/crewai/src/crewai/utilities/guardrail.py b/lib/crewai/src/crewai/utilities/guardrail.py
index e486abab8..53016d0b8 100644
--- a/lib/crewai/src/crewai/utilities/guardrail.py
+++ b/lib/crewai/src/crewai/utilities/guardrail.py
@@ -1,17 +1,22 @@
 from __future__ import annotations
 
-from collections.abc import Callable
-from typing import TYPE_CHECKING, Any
+from collections.abc import Callable, Sequence
+from typing import TYPE_CHECKING, Any, TypeAlias
 
 from pydantic import BaseModel, Field, field_validator
 from typing_extensions import Self
 
+
 if TYPE_CHECKING:
     from crewai.agents.agent_builder.base_agent import BaseAgent
     from crewai.lite_agent import LiteAgent, LiteAgentOutput
     from crewai.task import Task
     from crewai.tasks.task_output import TaskOutput
 
+GuardrailType: TypeAlias = Callable[["TaskOutput"], tuple[bool, Any]] | str | None
+
+GuardrailsType: TypeAlias = Sequence[GuardrailType] | GuardrailType
+
 
 class GuardrailResult(BaseModel):
     """Result from a task guardrail execution.
diff --git a/lib/crewai/tests/test_task_guardrails.py b/lib/crewai/tests/test_task_guardrails.py
index 5a7edcca4..5930079c0 100644
--- a/lib/crewai/tests/test_task_guardrails.py
+++ b/lib/crewai/tests/test_task_guardrails.py
@@ -1,7 +1,7 @@
-import threading
 from unittest.mock import Mock, patch
 
 import pytest
+
 from crewai import Agent, Task
 from crewai.events.event_bus import crewai_event_bus
 from crewai.events.event_types import (
@@ -14,6 +14,24 @@
 from crewai.tasks.llm_guardrail import LLMGuardrail
 from crewai.tasks.task_output import TaskOutput
 
+
+def create_smart_task(**kwargs):
+    """
+    Task factory that assigns a default test agent when guardrails are present, keeping older tests working now that string guardrails require an agent.
+ """ + guardrails_list = kwargs.get("guardrails") + has_guardrails = kwargs.get("guardrail") is not None or ( + guardrails_list is not None and len(guardrails_list) > 0 + ) + + if has_guardrails and kwargs.get("agent") is None: + kwargs["agent"] = Agent( + role="test_agent", goal="test_goal", backstory="test_backstory" + ) + + return Task(**kwargs) + + def test_task_without_guardrail(): """Test that tasks work normally without guardrails (backward compatibility).""" agent = Mock() @@ -21,7 +39,7 @@ def test_task_without_guardrail(): agent.execute_task.return_value = "test result" agent.crew = None - task = Task(description="Test task", expected_output="Output") + task = create_smart_task(description="Test task", expected_output="Output") result = task.execute_sync(agent=agent) assert isinstance(result, TaskOutput) @@ -39,7 +57,9 @@ def test_task_with_successful_guardrail_func(): agent.execute_task.return_value = "test result" agent.crew = None - task = Task(description="Test task", expected_output="Output", guardrail=guardrail) + task = create_smart_task( + description="Test task", expected_output="Output", guardrail=guardrail + ) result = task.execute_sync(agent=agent) assert isinstance(result, TaskOutput) @@ -57,7 +77,7 @@ def test_task_with_failing_guardrail(): agent.execute_task.side_effect = ["bad result", "good result"] agent.crew = None - task = Task( + task = create_smart_task( description="Test task", expected_output="Output", guardrail=guardrail, @@ -84,7 +104,7 @@ def test_task_with_guardrail_retries(): agent.execute_task.return_value = "bad result" agent.crew = None - task = Task( + task = create_smart_task( description="Test task", expected_output="Output", guardrail=guardrail, @@ -109,7 +129,7 @@ def test_guardrail_error_in_context(): agent.role = "test_agent" agent.crew = None - task = Task( + task = create_smart_task( description="Test task", expected_output="Output", guardrail=guardrail, @@ -176,92 +196,78 @@ def test_task_guardrail_process_output(task_output): def test_guardrail_emits_events(sample_agent): started_guardrail = [] completed_guardrail = [] - all_events_received = threading.Event() - expected_started = 3 # 2 from first task, 1 from second - expected_completed = 3 # 2 from first task, 1 from second - task1 = Task( + task = create_smart_task( description="Gather information about available books on the First World War", agent=sample_agent, expected_output="A list of available books on the First World War", guardrail="Ensure the authors are from Italy", ) - @crewai_event_bus.on(LLMGuardrailStartedEvent) - def handle_guardrail_started(source, event): - started_guardrail.append( - {"guardrail": event.guardrail, "retry_count": event.retry_count} - ) - if ( - len(started_guardrail) >= expected_started - and len(completed_guardrail) >= expected_completed - ): - all_events_received.set() + with crewai_event_bus.scoped_handlers(): - @crewai_event_bus.on(LLMGuardrailCompletedEvent) - def handle_guardrail_completed(source, event): - completed_guardrail.append( + @crewai_event_bus.on(LLMGuardrailStartedEvent) + def handle_guardrail_started(source, event): + assert source == task + started_guardrail.append( + {"guardrail": event.guardrail, "retry_count": event.retry_count} + ) + + @crewai_event_bus.on(LLMGuardrailCompletedEvent) + def handle_guardrail_completed(source, event): + assert source == task + completed_guardrail.append( + { + "success": event.success, + "result": event.result, + "error": event.error, + "retry_count": event.retry_count, + } + ) + + result = 
task.execute_sync(agent=sample_agent) + + def custom_guardrail(result: TaskOutput): + return (True, "good result from callable function") + + task = create_smart_task( + description="Test task", + expected_output="Output", + guardrail=custom_guardrail, + ) + + task.execute_sync(agent=sample_agent) + + expected_started_events = [ + {"guardrail": "Ensure the authors are from Italy", "retry_count": 0}, + {"guardrail": "Ensure the authors are from Italy", "retry_count": 1}, { - "success": event.success, - "result": event.result, - "error": event.error, - "retry_count": event.retry_count, - } - ) - if ( - len(started_guardrail) >= expected_started - and len(completed_guardrail) >= expected_completed - ): - all_events_received.set() + "guardrail": """def custom_guardrail(result: TaskOutput): + return (True, "good result from callable function")""", + "retry_count": 0, + }, + ] - result = task1.execute_sync(agent=sample_agent) - - def custom_guardrail(result: TaskOutput): - return (True, "good result from callable function") - - task2 = Task( - description="Test task", - expected_output="Output", - guardrail=custom_guardrail, - ) - - task2.execute_sync(agent=sample_agent) - - # Wait for all events to be received - assert all_events_received.wait(timeout=10), ( - "Timeout waiting for all guardrail events" - ) - - expected_started_events = [ - {"guardrail": "Ensure the authors are from Italy", "retry_count": 0}, - {"guardrail": "Ensure the authors are from Italy", "retry_count": 1}, - { - "guardrail": """def custom_guardrail(result: TaskOutput): - return (True, "good result from callable function")""", - "retry_count": 0, - }, - ] - - expected_completed_events = [ - { - "success": False, - "result": None, - "error": "The task result does not comply with the guardrail because none of " - "the listed authors are from Italy. All authors mentioned are from " - "different countries, including Germany, the UK, the USA, and others, " - "which violates the requirement that authors must be Italian.", - "retry_count": 0, - }, - {"success": True, "result": result.raw, "error": None, "retry_count": 1}, - { - "success": True, - "result": "good result from callable function", - "error": None, - "retry_count": 0, - }, - ] - assert started_guardrail == expected_started_events - assert completed_guardrail == expected_completed_events + expected_completed_events = [ + { + "success": False, + "result": None, + "error": "The task result does not comply with the guardrail because none of " + "the listed authors are from Italy. 
All authors mentioned are from " + "different countries, including Germany, the UK, the USA, and others, " + "which violates the requirement that authors must be Italian.", + "retry_count": 0, + }, + {"success": True, "result": result.raw, "error": None, "retry_count": 1}, + { + "success": True, + "result": "good result from callable function", + "error": None, + "retry_count": 0, + }, + ] + assert started_guardrail == expected_started_events + assert completed_guardrail == expected_completed_events @pytest.mark.vcr(filter_headers=["authorization"]) @@ -276,7 +282,7 @@ def test_guardrail_when_an_error_occurs(sample_agent, task_output): match="Error while validating the task output: Unexpected error", ), ): - task = Task( + task = create_smart_task( description="Gather information about available books on the First World War", agent=sample_agent, expected_output="A list of available books on the First World War", @@ -298,7 +304,7 @@ def test_hallucination_guardrail_integration(): context="Test reference context for validation", llm=mock_llm, threshold=8.0 ) - task = Task( + task = create_smart_task( description="Test task with hallucination guardrail", expected_output="Valid output", guardrail=guardrail, @@ -318,3 +324,401 @@ def test_hallucination_guardrail_description_in_events(): event = LLMGuardrailStartedEvent(guardrail=guardrail, retry_count=0) assert event.guardrail == "HallucinationGuardrail (no-op)" + + +def test_multiple_guardrails_sequential_processing(): + """Test that multiple guardrails are processed sequentially.""" + + def first_guardrail(result: TaskOutput) -> tuple[bool, str]: + """First guardrail adds prefix.""" + return (True, f"[FIRST] {result.raw}") + + def second_guardrail(result: TaskOutput) -> tuple[bool, str]: + """Second guardrail adds suffix.""" + return (True, f"{result.raw} [SECOND]") + + def third_guardrail(result: TaskOutput) -> tuple[bool, str]: + """Third guardrail converts to uppercase.""" + return (True, result.raw.upper()) + + agent = Mock() + agent.role = "sequential_agent" + agent.execute_task.return_value = "original text" + agent.crew = None + + task = create_smart_task( + description="Test sequential guardrails", + expected_output="Processed text", + guardrails=[first_guardrail, second_guardrail, third_guardrail], + ) + + result = task.execute_sync(agent=agent) + assert result.raw == "[FIRST] ORIGINAL TEXT [SECOND]" + + +def test_multiple_guardrails_with_validation_failure(): + """Test multiple guardrails where one fails validation.""" + + def length_guardrail(result: TaskOutput) -> tuple[bool, str]: + """Ensure minimum length.""" + if len(result.raw) < 10: + return (False, "Text too short") + return (True, result.raw) + + def format_guardrail(result: TaskOutput) -> tuple[bool, str]: + """Add formatting only if not already formatted.""" + if not result.raw.startswith("Formatted:"): + return (True, f"Formatted: {result.raw}") + return (True, result.raw) + + def validation_guardrail(result: TaskOutput) -> tuple[bool, str]: + """Final validation.""" + if "Formatted:" not in result.raw: + return (False, "Missing formatting") + return (True, result.raw) + + # Use a callable that tracks calls and returns appropriate values + call_count = 0 + + def mock_execute_task(*args, **kwargs): + nonlocal call_count + call_count += 1 + result = ( + "short" + if call_count == 1 + else "this is a longer text that meets requirements" + ) + return result + + agent = Mock() + agent.role = "validation_agent" + agent.execute_task = mock_execute_task + agent.crew = None + + 
task = create_smart_task( + description="Test guardrails with validation", + expected_output="Valid formatted text", + guardrails=[length_guardrail, format_guardrail, validation_guardrail], + guardrail_max_retries=2, + ) + + result = task.execute_sync(agent=agent) + # The second call should be processed through all guardrails + assert result.raw == "Formatted: this is a longer text that meets requirements" + assert task._guardrail_retry_counts.get(0, 0) == 1 + + +def test_multiple_guardrails_with_mixed_string_and_taskoutput(): + """Test guardrails that return both strings and TaskOutput objects.""" + + def string_guardrail(result: TaskOutput) -> tuple[bool, str]: + """Returns a string.""" + return (True, f"String: {result.raw}") + + def taskoutput_guardrail(result: TaskOutput) -> tuple[bool, TaskOutput]: + """Returns a TaskOutput object.""" + new_output = TaskOutput( + name=result.name, + description=result.description, + expected_output=result.expected_output, + raw=f"TaskOutput: {result.raw}", + agent=result.agent, + output_format=result.output_format, + ) + return (True, new_output) + + def final_string_guardrail(result: TaskOutput) -> tuple[bool, str]: + """Final string transformation.""" + return (True, f"Final: {result.raw}") + + agent = Mock() + agent.role = "mixed_agent" + agent.execute_task.return_value = "original" + agent.crew = None + + task = create_smart_task( + description="Test mixed return types", + expected_output="Mixed processing", + guardrails=[string_guardrail, taskoutput_guardrail, final_string_guardrail], + ) + + result = task.execute_sync(agent=agent) + assert result.raw == "Final: TaskOutput: String: original" + + +def test_multiple_guardrails_with_retry_on_middle_guardrail(): + """Test that retry works correctly when a middle guardrail fails.""" + + call_count = {"first": 0, "second": 0, "third": 0} + + def first_guardrail(result: TaskOutput) -> tuple[bool, str]: + """Always succeeds.""" + call_count["first"] += 1 + return (True, f"First({call_count['first']}): {result.raw}") + + def second_guardrail(result: TaskOutput) -> tuple[bool, str]: + """Fails on first attempt, succeeds on second.""" + call_count["second"] += 1 + if call_count["second"] == 1: + return (False, "Second guardrail failed on first attempt") + return (True, f"Second({call_count['second']}): {result.raw}") + + def third_guardrail(result: TaskOutput) -> tuple[bool, str]: + """Always succeeds.""" + call_count["third"] += 1 + return (True, f"Third({call_count['third']}): {result.raw}") + + agent = Mock() + agent.role = "retry_agent" + agent.execute_task.return_value = "base" + agent.crew = None + + task = create_smart_task( + description="Test retry in middle guardrail", + expected_output="Retry handling", + guardrails=[first_guardrail, second_guardrail, third_guardrail], + guardrail_max_retries=2, + ) + + result = task.execute_sync(agent=agent) + assert task._guardrail_retry_counts.get(1, 0) == 1 + assert call_count["first"] == 1 + assert call_count["second"] == 2 + assert call_count["third"] == 1 + assert "Second(2)" in result.raw + + +def test_multiple_guardrails_with_max_retries_exceeded(): + """Test that exception is raised when max retries exceeded with multiple guardrails.""" + + def passing_guardrail(result: TaskOutput) -> tuple[bool, str]: + """Always passes.""" + return (True, f"Passed: {result.raw}") + + def failing_guardrail(result: TaskOutput) -> tuple[bool, str]: + """Always fails.""" + return (False, "This guardrail always fails") + + agent = Mock() + agent.role = "failing_agent" + 
agent.execute_task.return_value = "test" + agent.crew = None + + task = create_smart_task( + description="Test max retries with multiple guardrails", + expected_output="Will fail", + guardrails=[passing_guardrail, failing_guardrail], + guardrail_max_retries=1, + ) + + with pytest.raises(Exception) as exc_info: + task.execute_sync(agent=agent) + + assert "Task failed guardrail 1 validation after 1 retries" in str(exc_info.value) + assert "This guardrail always fails" in str(exc_info.value) + assert task._guardrail_retry_counts.get(1, 0) == 1 + + +def test_multiple_guardrails_empty_list(): + """Test that empty guardrails list works correctly.""" + + agent = Mock() + agent.role = "empty_agent" + agent.execute_task.return_value = "no guardrails" + agent.crew = None + + task = create_smart_task( + description="Test empty guardrails list", + expected_output="No processing", + guardrails=[], + ) + + result = task.execute_sync(agent=agent) + assert result.raw == "no guardrails" + + +def test_multiple_guardrails_with_llm_guardrails(): + """Test mixing callable and LLM guardrails.""" + + def callable_guardrail(result: TaskOutput) -> tuple[bool, str]: + """Callable guardrail.""" + return (True, f"Callable: {result.raw}") + + # Create a proper mock agent without config issues + from crewai import Agent + + agent = Agent( + role="mixed_guardrail_agent", goal="Test goal", backstory="Test backstory" + ) + + task = create_smart_task( + description="Test mixed guardrail types", + expected_output="Mixed processing", + guardrails=[callable_guardrail, "Ensure the output is professional"], + agent=agent, + ) + + # The LLM guardrail will be converted to LLMGuardrail internally + assert len(task._guardrails) == 2 + assert callable(task._guardrails[0]) + assert callable(task._guardrails[1]) # LLMGuardrail is callable + + +def test_multiple_guardrails_processing_order(): + """Test that guardrails are processed in the correct order.""" + + processing_order = [] + + def first_guardrail(result: TaskOutput) -> tuple[bool, str]: + processing_order.append("first") + return (True, f"1-{result.raw}") + + def second_guardrail(result: TaskOutput) -> tuple[bool, str]: + processing_order.append("second") + return (True, f"2-{result.raw}") + + def third_guardrail(result: TaskOutput) -> tuple[bool, str]: + processing_order.append("third") + return (True, f"3-{result.raw}") + + agent = Mock() + agent.role = "order_agent" + agent.execute_task.return_value = "base" + agent.crew = None + + task = create_smart_task( + description="Test processing order", + expected_output="Ordered processing", + guardrails=[first_guardrail, second_guardrail, third_guardrail], + ) + + result = task.execute_sync(agent=agent) + assert processing_order == ["first", "second", "third"] + assert result.raw == "3-2-1-base" + + +def test_multiple_guardrails_with_pydantic_output(): + """Test multiple guardrails with Pydantic output model.""" + from pydantic import BaseModel, Field + + class TestModel(BaseModel): + content: str = Field(description="The content") + processed: bool = Field(description="Whether it was processed") + + def json_guardrail(result: TaskOutput) -> tuple[bool, str]: + """Convert to JSON format.""" + import json + + data = {"content": result.raw, "processed": True} + return (True, json.dumps(data)) + + def validation_guardrail(result: TaskOutput) -> tuple[bool, str]: + """Validate JSON structure.""" + import json + + try: + data = json.loads(result.raw) + if "content" not in data or "processed" not in data: + return (False, "Missing 
required fields") + return (True, result.raw) + except json.JSONDecodeError: + return (False, "Invalid JSON format") + + agent = Mock() + agent.role = "pydantic_agent" + agent.execute_task.return_value = "test content" + agent.crew = None + + task = create_smart_task( + description="Test guardrails with Pydantic", + expected_output="Structured output", + guardrails=[json_guardrail, validation_guardrail], + output_pydantic=TestModel, + ) + + result = task.execute_sync(agent=agent) + + # Verify the result is valid JSON and can be parsed + import json + + parsed = json.loads(result.raw) + assert parsed["content"] == "test content" + assert parsed["processed"] is True + + +def test_guardrails_vs_single_guardrail_mutual_exclusion(): + """Test that guardrails list nullifies single guardrail.""" + + def single_guardrail(result: TaskOutput) -> tuple[bool, str]: + """Single guardrail - should be ignored.""" + return (True, f"Single: {result.raw}") + + def list_guardrail(result: TaskOutput) -> tuple[bool, str]: + """List guardrail - should be used.""" + return (True, f"List: {result.raw}") + + agent = Mock() + agent.role = "exclusion_agent" + agent.execute_task.return_value = "test" + agent.crew = None + + task = create_smart_task( + description="Test mutual exclusion", + expected_output="Exclusion test", + guardrail=single_guardrail, # This should be ignored + guardrails=[list_guardrail], # This should be used + ) + + result = task.execute_sync(agent=agent) + # Should only use the guardrails list, not the single guardrail + assert result.raw == "List: test" + assert task._guardrail is None # Single guardrail should be nullified + + +def test_per_guardrail_independent_retry_tracking(): + """Test that each guardrail has independent retry tracking.""" + + call_counts = {"g1": 0, "g2": 0, "g3": 0} + + def guardrail_1(result: TaskOutput) -> tuple[bool, str]: + """Fails twice, then succeeds.""" + call_counts["g1"] += 1 + if call_counts["g1"] <= 2: + return (False, "Guardrail 1 not ready yet") + return (True, f"G1({call_counts['g1']}): {result.raw}") + + def guardrail_2(result: TaskOutput) -> tuple[bool, str]: + """Fails once, then succeeds.""" + call_counts["g2"] += 1 + if call_counts["g2"] == 1: + return (False, "Guardrail 2 not ready yet") + return (True, f"G2({call_counts['g2']}): {result.raw}") + + def guardrail_3(result: TaskOutput) -> tuple[bool, str]: + """Always succeeds.""" + call_counts["g3"] += 1 + return (True, f"G3({call_counts['g3']}): {result.raw}") + + agent = Mock() + agent.role = "independent_retry_agent" + agent.execute_task.return_value = "base" + agent.crew = None + + task = create_smart_task( + description="Test independent retry tracking", + expected_output="Independent retries", + guardrails=[guardrail_1, guardrail_2, guardrail_3], + guardrail_max_retries=3, + ) + + result = task.execute_sync(agent=agent) + + assert task._guardrail_retry_counts.get(0, 0) == 2 + assert task._guardrail_retry_counts.get(1, 0) == 1 + assert task._guardrail_retry_counts.get(2, 0) == 0 + + assert call_counts["g1"] == 3 + assert call_counts["g2"] == 2 + assert call_counts["g3"] == 1 + + assert "G3(1)" in result.raw