Enhance Task Class Guardrail Validation Logic

- Updated the Task class to ensure that guardrails can only be used when an agent is provided, improving error handling and validation.
- Introduced a smart task factory function to automatically assign a mock agent when guardrails are present, maintaining backward compatibility.
- Updated tests to utilize the new smart task factory, ensuring proper functionality with and without guardrails.

This update enhances the robustness of task execution and guardrail integration, ensuring better control over task validation outcomes.
This commit is contained in:
lorenzejay
2025-10-15 14:02:55 -07:00
parent d40846e6af
commit 594c21a36e
2 changed files with 44 additions and 22 deletions

View File

@@ -280,10 +280,12 @@ class Task(BaseModel):
@model_validator(mode="after") @model_validator(mode="after")
def ensure_guardrails_is_list_of_callables(self) -> "Task": def ensure_guardrails_is_list_of_callables(self) -> "Task":
guardrails = [] guardrails = []
if self.agent is None: if self.guardrails is not None and (
raise ValueError("Agent is required to use guardrails") not isinstance(self.guardrails, (list, tuple)) or len(self.guardrails) > 0
):
if self.agent is None:
raise ValueError("Agent is required to use guardrails")
if self.guardrails is not None:
if callable(self.guardrails): if callable(self.guardrails):
guardrails.append(self.guardrails) guardrails.append(self.guardrails)
elif isinstance(self.guardrails, str): elif isinstance(self.guardrails, str):

View File

@@ -14,6 +14,24 @@ from crewai.tasks.llm_guardrail import LLMGuardrail
from crewai.tasks.task_output import TaskOutput from crewai.tasks.task_output import TaskOutput
def create_smart_task(**kwargs):
"""
Smart task factory that automatically assigns a mock agent when guardrails are present.
This maintains backward compatibility while handling the agent requirement for guardrails.
"""
guardrails_list = kwargs.get("guardrails")
has_guardrails = kwargs.get("guardrail") is not None or (
guardrails_list is not None and len(guardrails_list) > 0
)
if has_guardrails and kwargs.get("agent") is None:
kwargs["agent"] = Agent(
role="test_agent", goal="test_goal", backstory="test_backstory"
)
return Task(**kwargs)
def test_task_without_guardrail(): def test_task_without_guardrail():
"""Test that tasks work normally without guardrails (backward compatibility).""" """Test that tasks work normally without guardrails (backward compatibility)."""
agent = Mock() agent = Mock()
@@ -21,7 +39,7 @@ def test_task_without_guardrail():
agent.execute_task.return_value = "test result" agent.execute_task.return_value = "test result"
agent.crew = None agent.crew = None
task = Task(description="Test task", expected_output="Output") task = create_smart_task(description="Test task", expected_output="Output")
result = task.execute_sync(agent=agent) result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput) assert isinstance(result, TaskOutput)
@@ -39,7 +57,9 @@ def test_task_with_successful_guardrail_func():
agent.execute_task.return_value = "test result" agent.execute_task.return_value = "test result"
agent.crew = None agent.crew = None
task = Task(description="Test task", expected_output="Output", guardrail=guardrail) task = create_smart_task(
description="Test task", expected_output="Output", guardrail=guardrail
)
result = task.execute_sync(agent=agent) result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput) assert isinstance(result, TaskOutput)
@@ -57,7 +77,7 @@ def test_task_with_failing_guardrail():
agent.execute_task.side_effect = ["bad result", "good result"] agent.execute_task.side_effect = ["bad result", "good result"]
agent.crew = None agent.crew = None
task = Task( task = create_smart_task(
description="Test task", description="Test task",
expected_output="Output", expected_output="Output",
guardrail=guardrail, guardrail=guardrail,
@@ -84,7 +104,7 @@ def test_task_with_guardrail_retries():
agent.execute_task.return_value = "bad result" agent.execute_task.return_value = "bad result"
agent.crew = None agent.crew = None
task = Task( task = create_smart_task(
description="Test task", description="Test task",
expected_output="Output", expected_output="Output",
guardrail=guardrail, guardrail=guardrail,
@@ -109,7 +129,7 @@ def test_guardrail_error_in_context():
agent.role = "test_agent" agent.role = "test_agent"
agent.crew = None agent.crew = None
task = Task( task = create_smart_task(
description="Test task", description="Test task",
expected_output="Output", expected_output="Output",
guardrail=guardrail, guardrail=guardrail,
@@ -177,7 +197,7 @@ def test_guardrail_emits_events(sample_agent):
started_guardrail = [] started_guardrail = []
completed_guardrail = [] completed_guardrail = []
task = Task( task = create_smart_task(
description="Gather information about available books on the First World War", description="Gather information about available books on the First World War",
agent=sample_agent, agent=sample_agent,
expected_output="A list of available books on the First World War", expected_output="A list of available books on the First World War",
@@ -210,7 +230,7 @@ def test_guardrail_emits_events(sample_agent):
def custom_guardrail(result: TaskOutput): def custom_guardrail(result: TaskOutput):
return (True, "good result from callable function") return (True, "good result from callable function")
task = Task( task = create_smart_task(
description="Test task", description="Test task",
expected_output="Output", expected_output="Output",
guardrail=custom_guardrail, guardrail=custom_guardrail,
@@ -262,7 +282,7 @@ def test_guardrail_when_an_error_occurs(sample_agent, task_output):
match="Error while validating the task output: Unexpected error", match="Error while validating the task output: Unexpected error",
), ),
): ):
task = Task( task = create_smart_task(
description="Gather information about available books on the First World War", description="Gather information about available books on the First World War",
agent=sample_agent, agent=sample_agent,
expected_output="A list of available books on the First World War", expected_output="A list of available books on the First World War",
@@ -284,7 +304,7 @@ def test_hallucination_guardrail_integration():
context="Test reference context for validation", llm=mock_llm, threshold=8.0 context="Test reference context for validation", llm=mock_llm, threshold=8.0
) )
task = Task( task = create_smart_task(
description="Test task with hallucination guardrail", description="Test task with hallucination guardrail",
expected_output="Valid output", expected_output="Valid output",
guardrail=guardrail, guardrail=guardrail,
@@ -326,7 +346,7 @@ def test_multiple_guardrails_sequential_processing():
agent.execute_task.return_value = "original text" agent.execute_task.return_value = "original text"
agent.crew = None agent.crew = None
task = Task( task = create_smart_task(
description="Test sequential guardrails", description="Test sequential guardrails",
expected_output="Processed text", expected_output="Processed text",
guardrails=[first_guardrail, second_guardrail, third_guardrail], guardrails=[first_guardrail, second_guardrail, third_guardrail],
@@ -375,7 +395,7 @@ def test_multiple_guardrails_with_validation_failure():
agent.execute_task = mock_execute_task agent.execute_task = mock_execute_task
agent.crew = None agent.crew = None
task = Task( task = create_smart_task(
description="Test guardrails with validation", description="Test guardrails with validation",
expected_output="Valid formatted text", expected_output="Valid formatted text",
guardrails=[length_guardrail, format_guardrail, validation_guardrail], guardrails=[length_guardrail, format_guardrail, validation_guardrail],
@@ -416,7 +436,7 @@ def test_multiple_guardrails_with_mixed_string_and_taskoutput():
agent.execute_task.return_value = "original" agent.execute_task.return_value = "original"
agent.crew = None agent.crew = None
task = Task( task = create_smart_task(
description="Test mixed return types", description="Test mixed return types",
expected_output="Mixed processing", expected_output="Mixed processing",
guardrails=[string_guardrail, taskoutput_guardrail, final_string_guardrail], guardrails=[string_guardrail, taskoutput_guardrail, final_string_guardrail],
@@ -453,7 +473,7 @@ def test_multiple_guardrails_with_retry_on_middle_guardrail():
agent.execute_task.return_value = "base" agent.execute_task.return_value = "base"
agent.crew = None agent.crew = None
task = Task( task = create_smart_task(
description="Test retry in middle guardrail", description="Test retry in middle guardrail",
expected_output="Retry handling", expected_output="Retry handling",
guardrails=[first_guardrail, second_guardrail, third_guardrail], guardrails=[first_guardrail, second_guardrail, third_guardrail],
@@ -484,7 +504,7 @@ def test_multiple_guardrails_with_max_retries_exceeded():
agent.execute_task.return_value = "test" agent.execute_task.return_value = "test"
agent.crew = None agent.crew = None
task = Task( task = create_smart_task(
description="Test max retries with multiple guardrails", description="Test max retries with multiple guardrails",
expected_output="Will fail", expected_output="Will fail",
guardrails=[passing_guardrail, failing_guardrail], guardrails=[passing_guardrail, failing_guardrail],
@@ -507,7 +527,7 @@ def test_multiple_guardrails_empty_list():
agent.execute_task.return_value = "no guardrails" agent.execute_task.return_value = "no guardrails"
agent.crew = None agent.crew = None
task = Task( task = create_smart_task(
description="Test empty guardrails list", description="Test empty guardrails list",
expected_output="No processing", expected_output="No processing",
guardrails=[], guardrails=[],
@@ -531,7 +551,7 @@ def test_multiple_guardrails_with_llm_guardrails():
role="mixed_guardrail_agent", goal="Test goal", backstory="Test backstory" role="mixed_guardrail_agent", goal="Test goal", backstory="Test backstory"
) )
task = Task( task = create_smart_task(
description="Test mixed guardrail types", description="Test mixed guardrail types",
expected_output="Mixed processing", expected_output="Mixed processing",
guardrails=[callable_guardrail, "Ensure the output is professional"], guardrails=[callable_guardrail, "Ensure the output is professional"],
@@ -566,7 +586,7 @@ def test_multiple_guardrails_processing_order():
agent.execute_task.return_value = "base" agent.execute_task.return_value = "base"
agent.crew = None agent.crew = None
task = Task( task = create_smart_task(
description="Test processing order", description="Test processing order",
expected_output="Ordered processing", expected_output="Ordered processing",
guardrails=[first_guardrail, second_guardrail, third_guardrail], guardrails=[first_guardrail, second_guardrail, third_guardrail],
@@ -609,7 +629,7 @@ def test_multiple_guardrails_with_pydantic_output():
agent.execute_task.return_value = "test content" agent.execute_task.return_value = "test content"
agent.crew = None agent.crew = None
task = Task( task = create_smart_task(
description="Test guardrails with Pydantic", description="Test guardrails with Pydantic",
expected_output="Structured output", expected_output="Structured output",
guardrails=[json_guardrail, validation_guardrail], guardrails=[json_guardrail, validation_guardrail],
@@ -642,7 +662,7 @@ def test_guardrails_vs_single_guardrail_mutual_exclusion():
agent.execute_task.return_value = "test" agent.execute_task.return_value = "test"
agent.crew = None agent.crew = None
task = Task( task = create_smart_task(
description="Test mutual exclusion", description="Test mutual exclusion",
expected_output="Exclusion test", expected_output="Exclusion test",
guardrail=single_guardrail, # This should be ignored guardrail=single_guardrail, # This should be ignored