refactor: implement code review suggestions

- Use typing.get_type_hints for better type checking
- Add proper handling of dict return types
- Improve parameter validation for keyword-only params
- Add comprehensive test coverage

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-20 16:54:27 +00:00
parent 0e086d348a
commit 31e8b9d7f2
2 changed files with 172 additions and 139 deletions

View File

@@ -198,23 +198,45 @@ class Task(BaseModel):
if param.default == inspect.Parameter.empty if param.default == inspect.Parameter.empty
and param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) and param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
] ]
if len(required_params) != 1: keyword_only_params = [
raise ValueError("Guardrail function must accept exactly one required positional parameter") param for param in sig.parameters.values()
if param.kind == inspect.Parameter.KEYWORD_ONLY
]
if len(required_params) != 1 or (len(keyword_only_params) > 0 and any(p.default == inspect.Parameter.empty for p in keyword_only_params)):
raise GuardrailValidationError(
"Guardrail function must accept exactly one required positional parameter and no required keyword-only parameters",
{"params": [str(p) for p in sig.parameters.values()]}
)
# Check return annotation if present, but don't require it # Check return annotation if present, but don't require it
type_hints = typing.get_type_hints(v) type_hints = typing.get_type_hints(v)
return_annotation = type_hints.get('return') return_annotation = type_hints.get('return')
if return_annotation: if return_annotation:
# Convert annotation to string for comparison # Convert annotation to string for comparison
annotation_str = str(return_annotation).lower() annotation_str = str(return_annotation).lower().replace(' ', '')
# Normalize type strings
normalized_annotation = (
annotation_str.replace('typing.', '')
.replace('dict[str,typing.any]', 'dict[str,any]')
.replace('dict[str, any]', 'dict[str,any]')
)
VALID_RETURN_TYPES = { VALID_RETURN_TYPES = {
'tuple[bool, any]': True, 'tuple[bool,any]',
'typing.tuple[bool, any]': True, 'tuple[bool,str]',
'tuple[bool, str]': True, 'tuple[bool,dict[str,any]]',
'tuple[bool, dict]': True, 'tuple[bool,taskoutput]'
'tuple[bool, taskoutput]': True
} }
if not any(pattern in annotation_str for pattern in VALID_RETURN_TYPES):
# Check if the normalized annotation matches any valid pattern
is_valid = False
for pattern in VALID_RETURN_TYPES:
if pattern == normalized_annotation or pattern == 'tuple[bool,any]':
is_valid = True
break
if not is_valid:
raise GuardrailValidationError( raise GuardrailValidationError(
f"Invalid return type annotation. Expected one of: " f"Invalid return type annotation. Expected one of: "
f"{', '.join(VALID_RETURN_TYPES.keys())}", f"{', '.join(VALID_RETURN_TYPES.keys())}",
@@ -446,6 +468,7 @@ class Task(BaseModel):
"Task guardrail returned None as result. This is not allowed." "Task guardrail returned None as result. This is not allowed."
) )
# Handle different result types
if isinstance(guardrail_result.result, str): if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output( pydantic_output, json_output = self._export_output(
@@ -455,6 +478,8 @@ class Task(BaseModel):
task_output.json_dict = json_output task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput): elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result task_output = guardrail_result.result
elif isinstance(guardrail_result.result, dict):
task_output.raw = guardrail_result.result
self.output = task_output self.output = task_output
self.end_time = datetime.datetime.now() self.end_time = datetime.datetime.now()

View File

@@ -1,171 +1,179 @@
"""Tests for task guardrails functionality.""" """Tests for task guardrails functionality."""
from typing import Dict, Any
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from crewai.task import Task from crewai.task import Task
from crewai.tasks.exceptions import GuardrailValidationError
from crewai.tasks.task_output import TaskOutput from crewai.tasks.task_output import TaskOutput
def test_task_without_guardrail(): class TestTaskGuardrails:
"""Test that tasks work normally without guardrails (backward compatibility).""" """Test suite for task guardrail functionality."""
agent = Mock()
agent.role = "test_agent"
agent.execute_task.return_value = "test result"
agent.crew = None
task = Task(description="Test task", expected_output="Output") @pytest.fixture
def mock_agent(self):
"""Fixture providing a mock agent for testing."""
agent = Mock()
agent.role = "test_agent"
agent.crew = None
return agent
result = task.execute_sync(agent=agent) def test_task_without_guardrail(self, mock_agent):
assert isinstance(result, TaskOutput) """Test that tasks work normally without guardrails (backward compatibility)."""
assert result.raw == "test result" mock_agent.execute_task.return_value = "test result"
task = Task(description="Test task", expected_output="Output")
result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput)
assert result.raw == "test result"
def test_task_with_successful_guardrail(): def test_task_with_successful_guardrail(self, mock_agent):
"""Test that successful guardrail validation passes transformed result.""" """Test that successful guardrail validation passes transformed result."""
def guardrail(result: TaskOutput):
return (True, result.raw.upper())
def guardrail(result: TaskOutput): mock_agent.execute_task.return_value = "test result"
return (True, result.raw.upper()) task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
agent = Mock() result = task.execute_sync(agent=mock_agent)
agent.role = "test_agent" assert isinstance(result, TaskOutput)
agent.execute_task.return_value = "test result" assert result.raw == "TEST RESULT"
agent.crew = None
task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput)
assert result.raw == "TEST RESULT"
def test_task_with_failing_guardrail(): def test_task_with_failing_guardrail(self, mock_agent):
"""Test that failing guardrail triggers retry with error context.""" """Test that failing guardrail triggers retry with error context."""
def guardrail(result: TaskOutput):
return (False, "Invalid format")
def guardrail(result: TaskOutput): mock_agent.execute_task.side_effect = ["bad result", "good result"]
return (False, "Invalid format") task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=1,
)
agent = Mock() # First execution fails guardrail, second succeeds
agent.role = "test_agent" mock_agent.execute_task.side_effect = ["bad result", "good result"]
agent.execute_task.side_effect = ["bad result", "good result"] with pytest.raises(Exception) as exc_info:
agent.crew = None task.execute_sync(agent=mock_agent)
task = Task( assert "Task failed guardrail validation" in str(exc_info.value)
description="Test task", assert task.retry_count == 1
expected_output="Output",
guardrail=guardrail,
max_retries=1,
)
# First execution fails guardrail, second succeeds
agent.execute_task.side_effect = ["bad result", "good result"]
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent)
assert "Task failed guardrail validation" in str(exc_info.value)
assert task.retry_count == 1
def test_task_with_guardrail_retries(): def test_task_with_guardrail_retries(self, mock_agent):
"""Test that guardrail respects max_retries configuration.""" """Test that guardrail respects max_retries configuration."""
def guardrail(result: TaskOutput):
return (False, "Invalid format")
def guardrail(result: TaskOutput): mock_agent.execute_task.return_value = "bad result"
return (False, "Invalid format") task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=2,
)
agent = Mock() with pytest.raises(Exception) as exc_info:
agent.role = "test_agent" task.execute_sync(agent=mock_agent)
agent.execute_task.return_value = "bad result"
agent.crew = None
task = Task( assert task.retry_count == 2
description="Test task", assert "Task failed guardrail validation after 2 retries" in str(exc_info.value)
expected_output="Output", assert "Invalid format" in str(exc_info.value)
guardrail=guardrail,
max_retries=2,
)
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent)
assert task.retry_count == 2
assert "Task failed guardrail validation after 2 retries" in str(exc_info.value)
assert "Invalid format" in str(exc_info.value)
def test_guardrail_error_in_context(): def test_guardrail_error_in_context(self, mock_agent):
"""Test that guardrail error is passed in context for retry.""" """Test that guardrail error is passed in context for retry."""
def guardrail(result: TaskOutput):
return (False, "Expected JSON, got string")
def guardrail(result: TaskOutput): task = Task(
return (False, "Expected JSON, got string") description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=1,
)
agent = Mock() # Mock execute_task to succeed on second attempt
agent.role = "test_agent" first_call = True
agent.crew = None def execute_task(task, context, tools):
nonlocal first_call
if first_call:
first_call = False
return "invalid"
return '{"valid": "json"}'
task = Task( mock_agent.execute_task.side_effect = execute_task
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=1,
)
# Mock execute_task to succeed on second attempt with pytest.raises(Exception) as exc_info:
first_call = True task.execute_sync(agent=mock_agent)
def execute_task(task, context, tools): assert "Task failed guardrail validation" in str(exc_info.value)
nonlocal first_call assert "Expected JSON, got string" in str(exc_info.value)
if first_call:
first_call = False
return "invalid"
return '{"valid": "json"}'
agent.execute_task.side_effect = execute_task
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent)
assert "Task failed guardrail validation" in str(exc_info.value)
assert "Expected JSON, got string" in str(exc_info.value)
def test_guardrail_with_new_style_annotation(): def test_guardrail_with_new_style_annotation(self, mock_agent):
"""Test guardrail with new style tuple annotation.""" """Test guardrail with new style tuple annotation."""
def guardrail(result: TaskOutput) -> tuple[bool, str]: def guardrail(result: TaskOutput) -> tuple[bool, str]:
return (True, result.raw.upper()) return (True, result.raw.upper())
agent = Mock() mock_agent.execute_task.return_value = "test result"
agent.role = "test_agent" task = Task(
agent.execute_task.return_value = "test result" description="Test task",
agent.crew = None expected_output="Output",
guardrail=guardrail
)
task = Task( result = task.execute_sync(agent=mock_agent)
description="Test task", assert isinstance(result, TaskOutput)
expected_output="Output", assert result.raw == "TEST RESULT"
guardrail=guardrail
)
result = task.execute_sync(agent=agent) def test_guardrail_with_optional_params(self, mock_agent):
assert isinstance(result, TaskOutput) """Test guardrail with optional parameters."""
assert result.raw == "TEST RESULT" def guardrail(result: TaskOutput, optional_param: str = "default") -> tuple[bool, str]:
return (True, f"{result.raw}-{optional_param}")
mock_agent.execute_task.return_value = "test"
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail
)
result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput)
assert result.raw == "test-default"
def test_guardrail_with_optional_params(): def test_guardrail_with_invalid_optional_params(self, mock_agent):
"""Test guardrail with optional parameters.""" """Test guardrail with invalid optional parameters."""
def guardrail(result: TaskOutput, optional_param: str = "default") -> tuple[bool, str]: def guardrail(result: TaskOutput, *, required_kwonly: str) -> tuple[bool, str]:
return (True, f"{result.raw}-{optional_param}") return (True, result.raw)
agent = Mock() with pytest.raises(GuardrailValidationError) as exc_info:
agent.role = "test_agent" Task(
agent.execute_task.return_value = "test" description="Test task",
agent.crew = None expected_output="Output",
guardrail=guardrail
)
assert "exactly one required positional parameter" in str(exc_info.value)
task = Task( def test_guardrail_with_dict_return_type(self, mock_agent):
description="Test task", """Test guardrail with dict return type."""
expected_output="Output", def guardrail(result: TaskOutput) -> tuple[bool, dict[str, Any]]:
guardrail=guardrail return (True, {"processed": result.raw.upper()})
)
mock_agent.execute_task.return_value = "test"
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail
)
result = task.execute_sync(agent=agent) result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput) assert isinstance(result, TaskOutput)
assert result.raw == "test-default" assert result.raw == {"processed": "TEST"}