Enhance Task Class Guardrail Validation Logic

- Updated the Task class to ensure that guardrails can only be used when an agent is provided, improving error handling and validation.
- Introduced a smart task factory function to automatically assign a mock agent when guardrails are present, maintaining backward compatibility.
- Updated tests to utilize the new smart task factory, ensuring proper functionality with and without guardrails.

This update enhances the robustness of task execution and guardrail integration, ensuring better control over task validation outcomes.
This commit is contained in:
lorenzejay
2025-10-15 14:02:55 -07:00
parent d40846e6af
commit 594c21a36e
2 changed files with 44 additions and 22 deletions

View File

@@ -280,10 +280,12 @@ class Task(BaseModel):
@model_validator(mode="after")
def ensure_guardrails_is_list_of_callables(self) -> "Task":
guardrails = []
if self.agent is None:
raise ValueError("Agent is required to use guardrails")
if self.guardrails is not None and (
not isinstance(self.guardrails, (list, tuple)) or len(self.guardrails) > 0
):
if self.agent is None:
raise ValueError("Agent is required to use guardrails")
if self.guardrails is not None:
if callable(self.guardrails):
guardrails.append(self.guardrails)
elif isinstance(self.guardrails, str):

View File

@@ -14,6 +14,24 @@ from crewai.tasks.llm_guardrail import LLMGuardrail
from crewai.tasks.task_output import TaskOutput
def create_smart_task(**kwargs):
"""
Smart task factory that automatically assigns a mock agent when guardrails are present.
This maintains backward compatibility while handling the agent requirement for guardrails.
"""
guardrails_list = kwargs.get("guardrails")
has_guardrails = kwargs.get("guardrail") is not None or (
guardrails_list is not None and len(guardrails_list) > 0
)
if has_guardrails and kwargs.get("agent") is None:
kwargs["agent"] = Agent(
role="test_agent", goal="test_goal", backstory="test_backstory"
)
return Task(**kwargs)
def test_task_without_guardrail():
"""Test that tasks work normally without guardrails (backward compatibility)."""
agent = Mock()
@@ -21,7 +39,7 @@ def test_task_without_guardrail():
agent.execute_task.return_value = "test result"
agent.crew = None
task = Task(description="Test task", expected_output="Output")
task = create_smart_task(description="Test task", expected_output="Output")
result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput)
@@ -39,7 +57,9 @@ def test_task_with_successful_guardrail_func():
agent.execute_task.return_value = "test result"
agent.crew = None
task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
task = create_smart_task(
description="Test task", expected_output="Output", guardrail=guardrail
)
result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput)
@@ -57,7 +77,7 @@ def test_task_with_failing_guardrail():
agent.execute_task.side_effect = ["bad result", "good result"]
agent.crew = None
task = Task(
task = create_smart_task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
@@ -84,7 +104,7 @@ def test_task_with_guardrail_retries():
agent.execute_task.return_value = "bad result"
agent.crew = None
task = Task(
task = create_smart_task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
@@ -109,7 +129,7 @@ def test_guardrail_error_in_context():
agent.role = "test_agent"
agent.crew = None
task = Task(
task = create_smart_task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
@@ -177,7 +197,7 @@ def test_guardrail_emits_events(sample_agent):
started_guardrail = []
completed_guardrail = []
task = Task(
task = create_smart_task(
description="Gather information about available books on the First World War",
agent=sample_agent,
expected_output="A list of available books on the First World War",
@@ -210,7 +230,7 @@ def test_guardrail_emits_events(sample_agent):
def custom_guardrail(result: TaskOutput):
return (True, "good result from callable function")
task = Task(
task = create_smart_task(
description="Test task",
expected_output="Output",
guardrail=custom_guardrail,
@@ -262,7 +282,7 @@ def test_guardrail_when_an_error_occurs(sample_agent, task_output):
match="Error while validating the task output: Unexpected error",
),
):
task = Task(
task = create_smart_task(
description="Gather information about available books on the First World War",
agent=sample_agent,
expected_output="A list of available books on the First World War",
@@ -284,7 +304,7 @@ def test_hallucination_guardrail_integration():
context="Test reference context for validation", llm=mock_llm, threshold=8.0
)
task = Task(
task = create_smart_task(
description="Test task with hallucination guardrail",
expected_output="Valid output",
guardrail=guardrail,
@@ -326,7 +346,7 @@ def test_multiple_guardrails_sequential_processing():
agent.execute_task.return_value = "original text"
agent.crew = None
task = Task(
task = create_smart_task(
description="Test sequential guardrails",
expected_output="Processed text",
guardrails=[first_guardrail, second_guardrail, third_guardrail],
@@ -375,7 +395,7 @@ def test_multiple_guardrails_with_validation_failure():
agent.execute_task = mock_execute_task
agent.crew = None
task = Task(
task = create_smart_task(
description="Test guardrails with validation",
expected_output="Valid formatted text",
guardrails=[length_guardrail, format_guardrail, validation_guardrail],
@@ -416,7 +436,7 @@ def test_multiple_guardrails_with_mixed_string_and_taskoutput():
agent.execute_task.return_value = "original"
agent.crew = None
task = Task(
task = create_smart_task(
description="Test mixed return types",
expected_output="Mixed processing",
guardrails=[string_guardrail, taskoutput_guardrail, final_string_guardrail],
@@ -453,7 +473,7 @@ def test_multiple_guardrails_with_retry_on_middle_guardrail():
agent.execute_task.return_value = "base"
agent.crew = None
task = Task(
task = create_smart_task(
description="Test retry in middle guardrail",
expected_output="Retry handling",
guardrails=[first_guardrail, second_guardrail, third_guardrail],
@@ -484,7 +504,7 @@ def test_multiple_guardrails_with_max_retries_exceeded():
agent.execute_task.return_value = "test"
agent.crew = None
task = Task(
task = create_smart_task(
description="Test max retries with multiple guardrails",
expected_output="Will fail",
guardrails=[passing_guardrail, failing_guardrail],
@@ -507,7 +527,7 @@ def test_multiple_guardrails_empty_list():
agent.execute_task.return_value = "no guardrails"
agent.crew = None
task = Task(
task = create_smart_task(
description="Test empty guardrails list",
expected_output="No processing",
guardrails=[],
@@ -531,7 +551,7 @@ def test_multiple_guardrails_with_llm_guardrails():
role="mixed_guardrail_agent", goal="Test goal", backstory="Test backstory"
)
task = Task(
task = create_smart_task(
description="Test mixed guardrail types",
expected_output="Mixed processing",
guardrails=[callable_guardrail, "Ensure the output is professional"],
@@ -566,7 +586,7 @@ def test_multiple_guardrails_processing_order():
agent.execute_task.return_value = "base"
agent.crew = None
task = Task(
task = create_smart_task(
description="Test processing order",
expected_output="Ordered processing",
guardrails=[first_guardrail, second_guardrail, third_guardrail],
@@ -609,7 +629,7 @@ def test_multiple_guardrails_with_pydantic_output():
agent.execute_task.return_value = "test content"
agent.crew = None
task = Task(
task = create_smart_task(
description="Test guardrails with Pydantic",
expected_output="Structured output",
guardrails=[json_guardrail, validation_guardrail],
@@ -642,7 +662,7 @@ def test_guardrails_vs_single_guardrail_mutual_exclusion():
agent.execute_task.return_value = "test"
agent.crew = None
task = Task(
task = create_smart_task(
description="Test mutual exclusion",
expected_output="Exclusion test",
guardrail=single_guardrail, # This should be ignored