feat: support to define a guardrail task no-code
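The tests added below exercise the new behaviour: a Task may now carry its guardrail as a plain-language string rather than a Python callable, and that string rule is enforced by an LLM-backed GuardrailTask that is re-run on a retry when validation fails. A rough usage sketch, not taken from the commit itself; the agent fields, model name and rule text simply mirror the fixtures in the diff, and the additional_instructions value is illustrative:

    from crewai import Agent, Task
    from crewai.llm import LLM
    from crewai.tasks.guardrail_task import GuardrailTask

    agent = Agent(role="Test Agent", goal="Test Goal", backstory="Test Backstory")

    # No-code form: the guardrail is a natural-language rule attached to the task.
    task = Task(
        description="Test task",
        expected_output="Output",
        guardrail="Ensure the output is equal to 'good result'",
    )

    # Explicit form: the same rule as a GuardrailTask bound to a chosen LLM,
    # optionally with extra instructions for the validator.
    task.guardrail = GuardrailTask(
        description="Ensure the output is equal to 'good result'",
        llm=LLM(model="gpt-4o-mini"),
        additional_instructions="Treat an empty answer as a failure.",  # illustrative value
    )

    result = task.execute_sync(agent=agent)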
@@ -1,11 +1,16 @@
 """Tests for task guardrails functionality."""

-from unittest.mock import Mock
+from unittest.mock import Mock, patch

 import pytest

-from crewai.task import Task
+from crewai import Agent, Task
+from crewai.llm import LLM
+from crewai.tasks.guardrail_task import GuardrailTask
 from crewai.tasks.task_output import TaskOutput
+from crewai.utilities.events import (
+    GuardrailTaskCompletedEvent,
+    GuardrailTaskStartedEvent,
+)
+from crewai.utilities.events.crewai_event_bus import crewai_event_bus


 def test_task_without_guardrail():
@@ -22,7 +27,7 @@ def test_task_without_guardrail():
     assert result.raw == "test result"


-def test_task_with_successful_guardrail():
+def test_task_with_successful_guardrail_func():
     """Test that successful guardrail validation passes transformed result."""

     def guardrail(result: TaskOutput):
@@ -127,3 +132,190 @@ def test_guardrail_error_in_context():
     assert "Task failed guardrail validation" in str(exc_info.value)
     assert "Expected JSON, got string" in str(exc_info.value)
+
+
+@pytest.fixture
+def sample_agent():
+    return Agent(role="Test Agent", goal="Test Goal", backstory="Test Backstory")
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_guardrail_using_llm(sample_agent):
+    task = Task(
+        description="Test task",
+        expected_output="Output",
+        guardrail="Ensure the output is equal to 'good result'",
+    )
+
+    with patch(
+        "crewai.tasks.guardrail_task.GuardrailTask.__call__",
+        side_effect=[(False, "bad result"), (True, "good result")],
+    ) as mock_guardrail:
+        task.execute_sync(agent=sample_agent)
+
+        assert mock_guardrail.call_count == 2
+
+    task.guardrail = GuardrailTask(
+        description="Ensure the output is equal to 'good result'",
+        llm=LLM(model="gpt-4o-mini"),
+    )
+
+    with patch(
+        "crewai.tasks.guardrail_task.GuardrailTask.__call__",
+        side_effect=[(False, "bad result"), (True, "good result")],
+    ) as mock_guardrail:
+        task.execute_sync(agent=sample_agent)
+
+        assert mock_guardrail.call_count == 2
+
+
+@pytest.fixture
+def task_output():
+    return TaskOutput(
+        raw="Test output",
+        description="Test task",
+        expected_output="Output",
+        agent="Test Agent",
+    )
+
+
+def test_guardrail_task_initialization_no_llm(task_output):
+    """Test GuardrailTask initialization fails without LLM"""
+    with pytest.raises(ValueError, match="Provide a valid LLM to the GuardrailTask"):
+        GuardrailTask(description="Test")(task_output)
+
+
+@pytest.fixture
+def mock_llm():
+    llm = Mock(spec=LLM)
+    llm.call.return_value = """
+output = 'Sample book data'
+if isinstance(output, str):
+    result = (True, output)
+else:
+    result = (False, 'Invalid output format')
+print(result)
+"""
+    return llm
+
+
+@pytest.mark.parametrize(
+    "tool_run_output",
+    [
+        {
+            "output": "(True, 'Valid output')",
+            "expected_result": True,
+            "expected_output": "Valid output",
+        },
+        {
+            "output": "(False, 'Invalid output format')",
+            "expected_result": False,
+            "expected_output": "Invalid output format",
+        },
+        {
+            "output": "Something went wrong while running the code, Invalid output format",
+            "expected_result": False,
+            "expected_output": "Something went wrong while running the code, Invalid output format",
+        },
+        {
+            "output": "No result variable found",
+            "expected_result": False,
+            "expected_output": "No result variable found",
+        },
+        {
+            "output": (False, "Invalid output format"),
+            "expected_result": False,
+            "expected_output": "Invalid output format",
+        },
+    ],
+)
+@patch("crewai_tools.CodeInterpreterTool.run")
+def test_guardrail_task_execute_code(mock_run, mock_llm, tool_run_output, task_output):
+    mock_run.return_value = tool_run_output["output"]
+
+    guardrail = GuardrailTask(description="Test validation", llm=mock_llm)
+
+    result = guardrail(task_output)
+    assert result[0] == tool_run_output["expected_result"]
+    assert result[1] == tool_run_output["expected_output"]
+
+
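A note on the mechanism the parametrized cases above imply (inferred from the mocks, not stated explicitly in the commit): the guardrail's LLM is expected to reply with a small Python validation snippet, that snippet is executed through crewai_tools.CodeInterpreterTool, and the tool's output is read back as a (passed, message) tuple; output that does not parse as such a tuple is treated as a failed validation whose message is the raw tool output. A hypothetical snippet in the shape the mock_llm fixture returns:

    # Hypothetical generated validation code, mirroring the mock_llm fixture above.
    output = 'Sample book data'          # stands in for the task output being validated
    if isinstance(output, str):          # the natural-language rule, rendered as code
        result = (True, output)          # passed: hand the (possibly transformed) output back
    else:
        result = (False, 'Invalid output format')  # failed: message surfaced for the retry
    print(result)                        # the interpreter's captured output is parsed as the tuple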
@patch("crewai_tools.CodeInterpreterTool.run")
|
||||
def test_guardrail_using_additional_instructions(mock_run, mock_llm, task_output):
|
||||
mock_run.return_value = "(True, 'Valid output')"
|
||||
additional_instructions = (
|
||||
"This is an additional instruction created by the user follow it strictly"
|
||||
)
|
||||
guardrail = GuardrailTask(
|
||||
description="Test validation",
|
||||
llm=mock_llm,
|
||||
additional_instructions=additional_instructions,
|
||||
)
|
||||
|
||||
guardrail(task_output)
|
||||
|
||||
assert additional_instructions in str(mock_llm.call.call_args)
|
||||
|
||||
|
||||
+
+
+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_guardrail_emits_events(sample_agent):
+    started_guardrail = []
+    completed_guardrail = []
+
+    with crewai_event_bus.scoped_handlers():
+
+        @crewai_event_bus.on(GuardrailTaskStartedEvent)
+        def handle_guardrail_started(source, event):
+            started_guardrail.append(
+                {"guardrail": event.guardrail, "retry_count": event.retry_count}
+            )
+
+        @crewai_event_bus.on(GuardrailTaskCompletedEvent)
+        def handle_guardrail_completed(source, event):
+            completed_guardrail.append(
+                {
+                    "success": event.success,
+                    "result": event.result,
+                    "error": event.error,
+                    "retry_count": event.retry_count,
+                }
+            )
+
+        task = Task(
+            description="Test task",
+            expected_output="Output",
+            guardrail="Ensure the output is equal to 'good result'",
+        )
+
+        with patch(
+            "crewai.tasks.guardrail_task.GuardrailTask.__call__",
+            side_effect=[(False, "bad result"), (True, "good result")],
+        ):
+            task.execute_sync(agent=sample_agent)
+
+        expected_started_events = [
+            {
+                "guardrail": "Ensure the output is equal to 'good result'",
+                "retry_count": 0,
+            },
+            {
+                "guardrail": "Ensure the output is equal to 'good result'",
+                "retry_count": 1,
+            },
+        ]
+        expected_completed_events = [
+            {
+                "success": False,
+                "result": None,
+                "error": "bad result",
+                "retry_count": 0,
+            },
+            {
+                "success": True,
+                "result": "good result",
+                "retry_count": 1,
+            },
+        ]
+        assert started_guardrail == expected_started_events
+        assert completed_guardrail == expected_completed_events
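Outside of the test suite, the same events can be observed with an ordinary listener. A minimal sketch, reusing only names that appear in this diff (the two event classes, crewai_event_bus, and the event fields asserted above); the handler names and print statements are illustrative:

    from crewai.utilities.events import (
        GuardrailTaskCompletedEvent,
        GuardrailTaskStartedEvent,
    )
    from crewai.utilities.events.crewai_event_bus import crewai_event_bus

    @crewai_event_bus.on(GuardrailTaskStartedEvent)
    def log_guardrail_started(source, event):
        # event.guardrail carries the rule text, event.retry_count the attempt number
        print(f"guardrail started (retry {event.retry_count}): {event.guardrail}")

    @crewai_event_bus.on(GuardrailTaskCompletedEvent)
    def log_guardrail_completed(source, event):
        # event.success / event.result on a pass, event.error on a failure
        status = "passed" if event.success else f"failed: {event.error}"
        print(f"guardrail {status} (retry {event.retry_count})")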