mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
Enhance Task Class Guardrail Validation Logic
- Updated the Task class to ensure that guardrails can only be used when an agent is provided, improving error handling and validation. - Introduced a smart task factory function to automatically assign a mock agent when guardrails are present, maintaining backward compatibility. - Updated tests to utilize the new smart task factory, ensuring proper functionality with and without guardrails. This update enhances the robustness of task execution and guardrail integration, ensuring better control over task validation outcomes.
This commit is contained in:
@@ -280,10 +280,12 @@ class Task(BaseModel):
|
|||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def ensure_guardrails_is_list_of_callables(self) -> "Task":
|
def ensure_guardrails_is_list_of_callables(self) -> "Task":
|
||||||
guardrails = []
|
guardrails = []
|
||||||
if self.agent is None:
|
if self.guardrails is not None and (
|
||||||
raise ValueError("Agent is required to use guardrails")
|
not isinstance(self.guardrails, (list, tuple)) or len(self.guardrails) > 0
|
||||||
|
):
|
||||||
|
if self.agent is None:
|
||||||
|
raise ValueError("Agent is required to use guardrails")
|
||||||
|
|
||||||
if self.guardrails is not None:
|
|
||||||
if callable(self.guardrails):
|
if callable(self.guardrails):
|
||||||
guardrails.append(self.guardrails)
|
guardrails.append(self.guardrails)
|
||||||
elif isinstance(self.guardrails, str):
|
elif isinstance(self.guardrails, str):
|
||||||
|
|||||||
@@ -14,6 +14,24 @@ from crewai.tasks.llm_guardrail import LLMGuardrail
|
|||||||
from crewai.tasks.task_output import TaskOutput
|
from crewai.tasks.task_output import TaskOutput
|
||||||
|
|
||||||
|
|
||||||
|
def create_smart_task(**kwargs):
|
||||||
|
"""
|
||||||
|
Smart task factory that automatically assigns a mock agent when guardrails are present.
|
||||||
|
This maintains backward compatibility while handling the agent requirement for guardrails.
|
||||||
|
"""
|
||||||
|
guardrails_list = kwargs.get("guardrails")
|
||||||
|
has_guardrails = kwargs.get("guardrail") is not None or (
|
||||||
|
guardrails_list is not None and len(guardrails_list) > 0
|
||||||
|
)
|
||||||
|
|
||||||
|
if has_guardrails and kwargs.get("agent") is None:
|
||||||
|
kwargs["agent"] = Agent(
|
||||||
|
role="test_agent", goal="test_goal", backstory="test_backstory"
|
||||||
|
)
|
||||||
|
|
||||||
|
return Task(**kwargs)
|
||||||
|
|
||||||
|
|
||||||
def test_task_without_guardrail():
|
def test_task_without_guardrail():
|
||||||
"""Test that tasks work normally without guardrails (backward compatibility)."""
|
"""Test that tasks work normally without guardrails (backward compatibility)."""
|
||||||
agent = Mock()
|
agent = Mock()
|
||||||
@@ -21,7 +39,7 @@ def test_task_without_guardrail():
|
|||||||
agent.execute_task.return_value = "test result"
|
agent.execute_task.return_value = "test result"
|
||||||
agent.crew = None
|
agent.crew = None
|
||||||
|
|
||||||
task = Task(description="Test task", expected_output="Output")
|
task = create_smart_task(description="Test task", expected_output="Output")
|
||||||
|
|
||||||
result = task.execute_sync(agent=agent)
|
result = task.execute_sync(agent=agent)
|
||||||
assert isinstance(result, TaskOutput)
|
assert isinstance(result, TaskOutput)
|
||||||
@@ -39,7 +57,9 @@ def test_task_with_successful_guardrail_func():
|
|||||||
agent.execute_task.return_value = "test result"
|
agent.execute_task.return_value = "test result"
|
||||||
agent.crew = None
|
agent.crew = None
|
||||||
|
|
||||||
task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
|
task = create_smart_task(
|
||||||
|
description="Test task", expected_output="Output", guardrail=guardrail
|
||||||
|
)
|
||||||
|
|
||||||
result = task.execute_sync(agent=agent)
|
result = task.execute_sync(agent=agent)
|
||||||
assert isinstance(result, TaskOutput)
|
assert isinstance(result, TaskOutput)
|
||||||
@@ -57,7 +77,7 @@ def test_task_with_failing_guardrail():
|
|||||||
agent.execute_task.side_effect = ["bad result", "good result"]
|
agent.execute_task.side_effect = ["bad result", "good result"]
|
||||||
agent.crew = None
|
agent.crew = None
|
||||||
|
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Test task",
|
description="Test task",
|
||||||
expected_output="Output",
|
expected_output="Output",
|
||||||
guardrail=guardrail,
|
guardrail=guardrail,
|
||||||
@@ -84,7 +104,7 @@ def test_task_with_guardrail_retries():
|
|||||||
agent.execute_task.return_value = "bad result"
|
agent.execute_task.return_value = "bad result"
|
||||||
agent.crew = None
|
agent.crew = None
|
||||||
|
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Test task",
|
description="Test task",
|
||||||
expected_output="Output",
|
expected_output="Output",
|
||||||
guardrail=guardrail,
|
guardrail=guardrail,
|
||||||
@@ -109,7 +129,7 @@ def test_guardrail_error_in_context():
|
|||||||
agent.role = "test_agent"
|
agent.role = "test_agent"
|
||||||
agent.crew = None
|
agent.crew = None
|
||||||
|
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Test task",
|
description="Test task",
|
||||||
expected_output="Output",
|
expected_output="Output",
|
||||||
guardrail=guardrail,
|
guardrail=guardrail,
|
||||||
@@ -177,7 +197,7 @@ def test_guardrail_emits_events(sample_agent):
|
|||||||
started_guardrail = []
|
started_guardrail = []
|
||||||
completed_guardrail = []
|
completed_guardrail = []
|
||||||
|
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Gather information about available books on the First World War",
|
description="Gather information about available books on the First World War",
|
||||||
agent=sample_agent,
|
agent=sample_agent,
|
||||||
expected_output="A list of available books on the First World War",
|
expected_output="A list of available books on the First World War",
|
||||||
@@ -210,7 +230,7 @@ def test_guardrail_emits_events(sample_agent):
|
|||||||
def custom_guardrail(result: TaskOutput):
|
def custom_guardrail(result: TaskOutput):
|
||||||
return (True, "good result from callable function")
|
return (True, "good result from callable function")
|
||||||
|
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Test task",
|
description="Test task",
|
||||||
expected_output="Output",
|
expected_output="Output",
|
||||||
guardrail=custom_guardrail,
|
guardrail=custom_guardrail,
|
||||||
@@ -262,7 +282,7 @@ def test_guardrail_when_an_error_occurs(sample_agent, task_output):
|
|||||||
match="Error while validating the task output: Unexpected error",
|
match="Error while validating the task output: Unexpected error",
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Gather information about available books on the First World War",
|
description="Gather information about available books on the First World War",
|
||||||
agent=sample_agent,
|
agent=sample_agent,
|
||||||
expected_output="A list of available books on the First World War",
|
expected_output="A list of available books on the First World War",
|
||||||
@@ -284,7 +304,7 @@ def test_hallucination_guardrail_integration():
|
|||||||
context="Test reference context for validation", llm=mock_llm, threshold=8.0
|
context="Test reference context for validation", llm=mock_llm, threshold=8.0
|
||||||
)
|
)
|
||||||
|
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Test task with hallucination guardrail",
|
description="Test task with hallucination guardrail",
|
||||||
expected_output="Valid output",
|
expected_output="Valid output",
|
||||||
guardrail=guardrail,
|
guardrail=guardrail,
|
||||||
@@ -326,7 +346,7 @@ def test_multiple_guardrails_sequential_processing():
|
|||||||
agent.execute_task.return_value = "original text"
|
agent.execute_task.return_value = "original text"
|
||||||
agent.crew = None
|
agent.crew = None
|
||||||
|
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Test sequential guardrails",
|
description="Test sequential guardrails",
|
||||||
expected_output="Processed text",
|
expected_output="Processed text",
|
||||||
guardrails=[first_guardrail, second_guardrail, third_guardrail],
|
guardrails=[first_guardrail, second_guardrail, third_guardrail],
|
||||||
@@ -375,7 +395,7 @@ def test_multiple_guardrails_with_validation_failure():
|
|||||||
agent.execute_task = mock_execute_task
|
agent.execute_task = mock_execute_task
|
||||||
agent.crew = None
|
agent.crew = None
|
||||||
|
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Test guardrails with validation",
|
description="Test guardrails with validation",
|
||||||
expected_output="Valid formatted text",
|
expected_output="Valid formatted text",
|
||||||
guardrails=[length_guardrail, format_guardrail, validation_guardrail],
|
guardrails=[length_guardrail, format_guardrail, validation_guardrail],
|
||||||
@@ -416,7 +436,7 @@ def test_multiple_guardrails_with_mixed_string_and_taskoutput():
|
|||||||
agent.execute_task.return_value = "original"
|
agent.execute_task.return_value = "original"
|
||||||
agent.crew = None
|
agent.crew = None
|
||||||
|
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Test mixed return types",
|
description="Test mixed return types",
|
||||||
expected_output="Mixed processing",
|
expected_output="Mixed processing",
|
||||||
guardrails=[string_guardrail, taskoutput_guardrail, final_string_guardrail],
|
guardrails=[string_guardrail, taskoutput_guardrail, final_string_guardrail],
|
||||||
@@ -453,7 +473,7 @@ def test_multiple_guardrails_with_retry_on_middle_guardrail():
|
|||||||
agent.execute_task.return_value = "base"
|
agent.execute_task.return_value = "base"
|
||||||
agent.crew = None
|
agent.crew = None
|
||||||
|
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Test retry in middle guardrail",
|
description="Test retry in middle guardrail",
|
||||||
expected_output="Retry handling",
|
expected_output="Retry handling",
|
||||||
guardrails=[first_guardrail, second_guardrail, third_guardrail],
|
guardrails=[first_guardrail, second_guardrail, third_guardrail],
|
||||||
@@ -484,7 +504,7 @@ def test_multiple_guardrails_with_max_retries_exceeded():
|
|||||||
agent.execute_task.return_value = "test"
|
agent.execute_task.return_value = "test"
|
||||||
agent.crew = None
|
agent.crew = None
|
||||||
|
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Test max retries with multiple guardrails",
|
description="Test max retries with multiple guardrails",
|
||||||
expected_output="Will fail",
|
expected_output="Will fail",
|
||||||
guardrails=[passing_guardrail, failing_guardrail],
|
guardrails=[passing_guardrail, failing_guardrail],
|
||||||
@@ -507,7 +527,7 @@ def test_multiple_guardrails_empty_list():
|
|||||||
agent.execute_task.return_value = "no guardrails"
|
agent.execute_task.return_value = "no guardrails"
|
||||||
agent.crew = None
|
agent.crew = None
|
||||||
|
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Test empty guardrails list",
|
description="Test empty guardrails list",
|
||||||
expected_output="No processing",
|
expected_output="No processing",
|
||||||
guardrails=[],
|
guardrails=[],
|
||||||
@@ -531,7 +551,7 @@ def test_multiple_guardrails_with_llm_guardrails():
|
|||||||
role="mixed_guardrail_agent", goal="Test goal", backstory="Test backstory"
|
role="mixed_guardrail_agent", goal="Test goal", backstory="Test backstory"
|
||||||
)
|
)
|
||||||
|
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Test mixed guardrail types",
|
description="Test mixed guardrail types",
|
||||||
expected_output="Mixed processing",
|
expected_output="Mixed processing",
|
||||||
guardrails=[callable_guardrail, "Ensure the output is professional"],
|
guardrails=[callable_guardrail, "Ensure the output is professional"],
|
||||||
@@ -566,7 +586,7 @@ def test_multiple_guardrails_processing_order():
|
|||||||
agent.execute_task.return_value = "base"
|
agent.execute_task.return_value = "base"
|
||||||
agent.crew = None
|
agent.crew = None
|
||||||
|
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Test processing order",
|
description="Test processing order",
|
||||||
expected_output="Ordered processing",
|
expected_output="Ordered processing",
|
||||||
guardrails=[first_guardrail, second_guardrail, third_guardrail],
|
guardrails=[first_guardrail, second_guardrail, third_guardrail],
|
||||||
@@ -609,7 +629,7 @@ def test_multiple_guardrails_with_pydantic_output():
|
|||||||
agent.execute_task.return_value = "test content"
|
agent.execute_task.return_value = "test content"
|
||||||
agent.crew = None
|
agent.crew = None
|
||||||
|
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Test guardrails with Pydantic",
|
description="Test guardrails with Pydantic",
|
||||||
expected_output="Structured output",
|
expected_output="Structured output",
|
||||||
guardrails=[json_guardrail, validation_guardrail],
|
guardrails=[json_guardrail, validation_guardrail],
|
||||||
@@ -642,7 +662,7 @@ def test_guardrails_vs_single_guardrail_mutual_exclusion():
|
|||||||
agent.execute_task.return_value = "test"
|
agent.execute_task.return_value = "test"
|
||||||
agent.crew = None
|
agent.crew = None
|
||||||
|
|
||||||
task = Task(
|
task = create_smart_task(
|
||||||
description="Test mutual exclusion",
|
description="Test mutual exclusion",
|
||||||
expected_output="Exclusion test",
|
expected_output="Exclusion test",
|
||||||
guardrail=single_guardrail, # This should be ignored
|
guardrail=single_guardrail, # This should be ignored
|
||||||
|
|||||||
Reference in New Issue
Block a user