Mirror of https://github.com/crewAIInc/crewAI.git (synced 2025-12-16 04:18:35 +00:00)
Enhance Task Class Guardrail Validation Logic
- Updated the Task class so that guardrails can only be used when an agent is provided, improving error handling and validation.
- Introduced a smart task factory function that automatically assigns a mock agent when guardrails are present, maintaining backward compatibility.
- Updated the tests to use the new smart task factory, covering behavior both with and without guardrails.

This update makes task execution and guardrail integration more robust and gives better control over task validation outcomes.
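For illustration only (not part of the commit): a minimal sketch of the behavior described above, assuming the top-level crewai imports and toy argument values.

    # Sketch: a guardrail without an agent should now be rejected,
    # while a plain task without guardrails keeps working.
    from crewai import Agent, Task

    def no_profanity(output):
        # toy guardrail returning the (success, data) tuple used in the tests below
        return (True, output)

    # Works: no guardrails, so no agent is required at construction time.
    Task(description="Summarize the report", expected_output="A summary")

    # Raises ValueError("Agent is required to use guardrails"):
    try:
        Task(
            description="Summarize the report",
            expected_output="A summary",
            guardrails=[no_profanity],
        )
    except ValueError as err:
        print(err)

    # Works: guardrails together with an agent.
    Task(
        description="Summarize the report",
        expected_output="A summary",
        agent=Agent(role="editor", goal="review", backstory="test"),
        guardrails=[no_profanity],
    )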
@@ -280,10 +280,12 @@ class Task(BaseModel):
     @model_validator(mode="after")
     def ensure_guardrails_is_list_of_callables(self) -> "Task":
         guardrails = []
-        if self.agent is None:
-            raise ValueError("Agent is required to use guardrails")
+        if self.guardrails is not None and (
+            not isinstance(self.guardrails, (list, tuple)) or len(self.guardrails) > 0
+        ):
+            if self.agent is None:
+                raise ValueError("Agent is required to use guardrails")

         if self.guardrails is not None:
             if callable(self.guardrails):
                 guardrails.append(self.guardrails)
             elif isinstance(self.guardrails, str):
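The hunk above is truncated after the `elif isinstance(self.guardrails, str):` branch. As a rough standalone illustration (not the crewAI implementation) of the normalization pattern this validator follows — single callable, description string, or list in; list of callables out, with the agent requirement applied only when guardrails are present — the string-to-LLM-guardrail step is stubbed because its details are not shown here.

    from typing import Any, Callable, Sequence

    def normalize_guardrails(
        guardrails: Callable | str | Sequence[Callable | str] | None,
        agent: Any,
    ) -> list[Callable]:
        """Toy version of the validation/normalization step (illustrative only)."""
        normalized: list[Callable] = []

        # Require an agent only when guardrails were actually supplied.
        if guardrails is not None and (
            not isinstance(guardrails, (list, tuple)) or len(guardrails) > 0
        ):
            if agent is None:
                raise ValueError("Agent is required to use guardrails")

        if guardrails is None:
            return normalized

        items = guardrails if isinstance(guardrails, (list, tuple)) else [guardrails]
        for item in items:
            if callable(item):
                normalized.append(item)
            elif isinstance(item, str):
                # In crewAI a string becomes an LLM-backed guardrail; stubbed here.
                normalized.append(lambda output, description=item: (True, output))
        return normalized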
@@ -14,6 +14,24 @@ from crewai.tasks.llm_guardrail import LLMGuardrail
 from crewai.tasks.task_output import TaskOutput


+def create_smart_task(**kwargs):
+    """
+    Smart task factory that automatically assigns a mock agent when guardrails are present.
+    This maintains backward compatibility while handling the agent requirement for guardrails.
+    """
+    guardrails_list = kwargs.get("guardrails")
+    has_guardrails = kwargs.get("guardrail") is not None or (
+        guardrails_list is not None and len(guardrails_list) > 0
+    )
+
+    if has_guardrails and kwargs.get("agent") is None:
+        kwargs["agent"] = Agent(
+            role="test_agent", goal="test_goal", backstory="test_backstory"
+        )
+
+    return Task(**kwargs)
+
+
 def test_task_without_guardrail():
     """Test that tasks work normally without guardrails (backward compatibility)."""
     agent = Mock()
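In the remaining hunks the tests simply swap Task(...) for create_smart_task(...). For orientation, a hypothetical snippet in the style of those tests (not part of the diff) showing both branches of the factory:

    # No guardrails: behaves exactly like Task(...), no agent is injected.
    plain = create_smart_task(description="Test task", expected_output="Output")
    assert plain.agent is None

    # Guardrails present and no agent given: a stub Agent is injected automatically.
    guarded = create_smart_task(
        description="Test task",
        expected_output="Output",
        guardrails=[lambda output: (True, output)],
    )
    assert guarded.agent is not None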
@@ -21,7 +39,7 @@ def test_task_without_guardrail():
     agent.execute_task.return_value = "test result"
     agent.crew = None

-    task = Task(description="Test task", expected_output="Output")
+    task = create_smart_task(description="Test task", expected_output="Output")

     result = task.execute_sync(agent=agent)
     assert isinstance(result, TaskOutput)
@@ -39,7 +57,9 @@ def test_task_with_successful_guardrail_func():
     agent.execute_task.return_value = "test result"
     agent.crew = None

-    task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
+    task = create_smart_task(
+        description="Test task", expected_output="Output", guardrail=guardrail
+    )

     result = task.execute_sync(agent=agent)
     assert isinstance(result, TaskOutput)
@@ -57,7 +77,7 @@ def test_task_with_failing_guardrail():
     agent.execute_task.side_effect = ["bad result", "good result"]
     agent.crew = None

-    task = Task(
+    task = create_smart_task(
         description="Test task",
         expected_output="Output",
         guardrail=guardrail,
@@ -84,7 +104,7 @@ def test_task_with_guardrail_retries():
     agent.execute_task.return_value = "bad result"
     agent.crew = None

-    task = Task(
+    task = create_smart_task(
         description="Test task",
         expected_output="Output",
         guardrail=guardrail,
@@ -109,7 +129,7 @@ def test_guardrail_error_in_context():
     agent.role = "test_agent"
     agent.crew = None

-    task = Task(
+    task = create_smart_task(
         description="Test task",
         expected_output="Output",
         guardrail=guardrail,
@@ -177,7 +197,7 @@ def test_guardrail_emits_events(sample_agent):
     started_guardrail = []
     completed_guardrail = []

-    task = Task(
+    task = create_smart_task(
         description="Gather information about available books on the First World War",
         agent=sample_agent,
         expected_output="A list of available books on the First World War",
@@ -210,7 +230,7 @@ def test_guardrail_emits_events(sample_agent):
     def custom_guardrail(result: TaskOutput):
         return (True, "good result from callable function")

-    task = Task(
+    task = create_smart_task(
         description="Test task",
         expected_output="Output",
         guardrail=custom_guardrail,
@@ -262,7 +282,7 @@ def test_guardrail_when_an_error_occurs(sample_agent, task_output):
             match="Error while validating the task output: Unexpected error",
         ),
     ):
-        task = Task(
+        task = create_smart_task(
             description="Gather information about available books on the First World War",
             agent=sample_agent,
             expected_output="A list of available books on the First World War",
@@ -284,7 +304,7 @@ def test_hallucination_guardrail_integration():
         context="Test reference context for validation", llm=mock_llm, threshold=8.0
     )

-    task = Task(
+    task = create_smart_task(
         description="Test task with hallucination guardrail",
         expected_output="Valid output",
         guardrail=guardrail,
@@ -326,7 +346,7 @@ def test_multiple_guardrails_sequential_processing():
     agent.execute_task.return_value = "original text"
     agent.crew = None

-    task = Task(
+    task = create_smart_task(
         description="Test sequential guardrails",
         expected_output="Processed text",
         guardrails=[first_guardrail, second_guardrail, third_guardrail],
@@ -375,7 +395,7 @@ def test_multiple_guardrails_with_validation_failure():
     agent.execute_task = mock_execute_task
     agent.crew = None

-    task = Task(
+    task = create_smart_task(
         description="Test guardrails with validation",
         expected_output="Valid formatted text",
         guardrails=[length_guardrail, format_guardrail, validation_guardrail],
@@ -416,7 +436,7 @@ def test_multiple_guardrails_with_mixed_string_and_taskoutput():
     agent.execute_task.return_value = "original"
     agent.crew = None

-    task = Task(
+    task = create_smart_task(
         description="Test mixed return types",
         expected_output="Mixed processing",
         guardrails=[string_guardrail, taskoutput_guardrail, final_string_guardrail],
@@ -453,7 +473,7 @@ def test_multiple_guardrails_with_retry_on_middle_guardrail():
     agent.execute_task.return_value = "base"
     agent.crew = None

-    task = Task(
+    task = create_smart_task(
         description="Test retry in middle guardrail",
         expected_output="Retry handling",
         guardrails=[first_guardrail, second_guardrail, third_guardrail],
@@ -484,7 +504,7 @@ def test_multiple_guardrails_with_max_retries_exceeded():
     agent.execute_task.return_value = "test"
     agent.crew = None

-    task = Task(
+    task = create_smart_task(
         description="Test max retries with multiple guardrails",
         expected_output="Will fail",
         guardrails=[passing_guardrail, failing_guardrail],
@@ -507,7 +527,7 @@ def test_multiple_guardrails_empty_list():
     agent.execute_task.return_value = "no guardrails"
     agent.crew = None

-    task = Task(
+    task = create_smart_task(
         description="Test empty guardrails list",
         expected_output="No processing",
         guardrails=[],
@@ -531,7 +551,7 @@ def test_multiple_guardrails_with_llm_guardrails():
         role="mixed_guardrail_agent", goal="Test goal", backstory="Test backstory"
     )

-    task = Task(
+    task = create_smart_task(
         description="Test mixed guardrail types",
         expected_output="Mixed processing",
         guardrails=[callable_guardrail, "Ensure the output is professional"],
@@ -566,7 +586,7 @@ def test_multiple_guardrails_processing_order():
     agent.execute_task.return_value = "base"
     agent.crew = None

-    task = Task(
+    task = create_smart_task(
         description="Test processing order",
         expected_output="Ordered processing",
         guardrails=[first_guardrail, second_guardrail, third_guardrail],
@@ -609,7 +629,7 @@ def test_multiple_guardrails_with_pydantic_output():
     agent.execute_task.return_value = "test content"
     agent.crew = None

-    task = Task(
+    task = create_smart_task(
         description="Test guardrails with Pydantic",
         expected_output="Structured output",
         guardrails=[json_guardrail, validation_guardrail],
@@ -642,7 +662,7 @@ def test_guardrails_vs_single_guardrail_mutual_exclusion():
     agent.execute_task.return_value = "test"
     agent.crew = None

-    task = Task(
+    task = create_smart_task(
         description="Test mutual exclusion",
         expected_output="Exclusion test",
         guardrail=single_guardrail,  # This should be ignored