refactor: implement code review suggestions

- Use typing.get_type_hints for better type checking
- Add proper handling of dict return types
- Improve parameter validation for keyword-only params
- Add comprehensive test coverage

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2025-02-20 16:54:27 +00:00
parent 0e086d348a
commit 31e8b9d7f2
2 changed files with 172 additions and 139 deletions

View File

@@ -198,23 +198,45 @@ class Task(BaseModel):
if param.default == inspect.Parameter.empty if param.default == inspect.Parameter.empty
and param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) and param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
] ]
if len(required_params) != 1: keyword_only_params = [
raise ValueError("Guardrail function must accept exactly one required positional parameter") param for param in sig.parameters.values()
if param.kind == inspect.Parameter.KEYWORD_ONLY
]
if len(required_params) != 1 or (len(keyword_only_params) > 0 and any(p.default == inspect.Parameter.empty for p in keyword_only_params)):
raise GuardrailValidationError(
"Guardrail function must accept exactly one required positional parameter and no required keyword-only parameters",
{"params": [str(p) for p in sig.parameters.values()]}
)
# Check return annotation if present, but don't require it # Check return annotation if present, but don't require it
type_hints = typing.get_type_hints(v) type_hints = typing.get_type_hints(v)
return_annotation = type_hints.get('return') return_annotation = type_hints.get('return')
if return_annotation: if return_annotation:
# Convert annotation to string for comparison # Convert annotation to string for comparison
annotation_str = str(return_annotation).lower() annotation_str = str(return_annotation).lower().replace(' ', '')
# Normalize type strings
normalized_annotation = (
annotation_str.replace('typing.', '')
.replace('dict[str,typing.any]', 'dict[str,any]')
.replace('dict[str, any]', 'dict[str,any]')
)
VALID_RETURN_TYPES = { VALID_RETURN_TYPES = {
'tuple[bool, any]': True, 'tuple[bool,any]',
'typing.tuple[bool, any]': True, 'tuple[bool,str]',
'tuple[bool, str]': True, 'tuple[bool,dict[str,any]]',
'tuple[bool, dict]': True, 'tuple[bool,taskoutput]'
'tuple[bool, taskoutput]': True
} }
if not any(pattern in annotation_str for pattern in VALID_RETURN_TYPES):
# Check if the normalized annotation matches any valid pattern
is_valid = False
for pattern in VALID_RETURN_TYPES:
if pattern == normalized_annotation or pattern == 'tuple[bool,any]':
is_valid = True
break
if not is_valid:
raise GuardrailValidationError( raise GuardrailValidationError(
f"Invalid return type annotation. Expected one of: " f"Invalid return type annotation. Expected one of: "
f"{', '.join(VALID_RETURN_TYPES.keys())}", f"{', '.join(VALID_RETURN_TYPES.keys())}",
@@ -446,6 +468,7 @@ class Task(BaseModel):
"Task guardrail returned None as result. This is not allowed." "Task guardrail returned None as result. This is not allowed."
) )
# Handle different result types
if isinstance(guardrail_result.result, str): if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output( pydantic_output, json_output = self._export_output(
@@ -455,6 +478,8 @@ class Task(BaseModel):
task_output.json_dict = json_output task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput): elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result task_output = guardrail_result.result
elif isinstance(guardrail_result.result, dict):
task_output.raw = guardrail_result.result
self.output = task_output self.output = task_output
self.end_time = datetime.datetime.now() self.end_time = datetime.datetime.now()

View File

@@ -1,56 +1,55 @@
"""Tests for task guardrails functionality.""" """Tests for task guardrails functionality."""
from typing import Dict, Any
from unittest.mock import Mock from unittest.mock import Mock
import pytest import pytest
from crewai.task import Task from crewai.task import Task
from crewai.tasks.exceptions import GuardrailValidationError
from crewai.tasks.task_output import TaskOutput from crewai.tasks.task_output import TaskOutput
def test_task_without_guardrail(): class TestTaskGuardrails:
"""Test that tasks work normally without guardrails (backward compatibility).""" """Test suite for task guardrail functionality."""
@pytest.fixture
def mock_agent(self):
"""Fixture providing a mock agent for testing."""
agent = Mock() agent = Mock()
agent.role = "test_agent" agent.role = "test_agent"
agent.execute_task.return_value = "test result"
agent.crew = None agent.crew = None
return agent
def test_task_without_guardrail(self, mock_agent):
"""Test that tasks work normally without guardrails (backward compatibility)."""
mock_agent.execute_task.return_value = "test result"
task = Task(description="Test task", expected_output="Output") task = Task(description="Test task", expected_output="Output")
result = task.execute_sync(agent=agent) result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput) assert isinstance(result, TaskOutput)
assert result.raw == "test result" assert result.raw == "test result"
def test_task_with_successful_guardrail(): def test_task_with_successful_guardrail(self, mock_agent):
"""Test that successful guardrail validation passes transformed result.""" """Test that successful guardrail validation passes transformed result."""
def guardrail(result: TaskOutput): def guardrail(result: TaskOutput):
return (True, result.raw.upper()) return (True, result.raw.upper())
agent = Mock() mock_agent.execute_task.return_value = "test result"
agent.role = "test_agent"
agent.execute_task.return_value = "test result"
agent.crew = None
task = Task(description="Test task", expected_output="Output", guardrail=guardrail) task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
result = task.execute_sync(agent=agent) result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput) assert isinstance(result, TaskOutput)
assert result.raw == "TEST RESULT" assert result.raw == "TEST RESULT"
def test_task_with_failing_guardrail(): def test_task_with_failing_guardrail(self, mock_agent):
"""Test that failing guardrail triggers retry with error context.""" """Test that failing guardrail triggers retry with error context."""
def guardrail(result: TaskOutput): def guardrail(result: TaskOutput):
return (False, "Invalid format") return (False, "Invalid format")
agent = Mock() mock_agent.execute_task.side_effect = ["bad result", "good result"]
agent.role = "test_agent"
agent.execute_task.side_effect = ["bad result", "good result"]
agent.crew = None
task = Task( task = Task(
description="Test task", description="Test task",
expected_output="Output", expected_output="Output",
@@ -59,25 +58,20 @@ def test_task_with_failing_guardrail():
) )
# First execution fails guardrail, second succeeds # First execution fails guardrail, second succeeds
agent.execute_task.side_effect = ["bad result", "good result"] mock_agent.execute_task.side_effect = ["bad result", "good result"]
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent) task.execute_sync(agent=mock_agent)
assert "Task failed guardrail validation" in str(exc_info.value) assert "Task failed guardrail validation" in str(exc_info.value)
assert task.retry_count == 1 assert task.retry_count == 1
def test_task_with_guardrail_retries(): def test_task_with_guardrail_retries(self, mock_agent):
"""Test that guardrail respects max_retries configuration.""" """Test that guardrail respects max_retries configuration."""
def guardrail(result: TaskOutput): def guardrail(result: TaskOutput):
return (False, "Invalid format") return (False, "Invalid format")
agent = Mock() mock_agent.execute_task.return_value = "bad result"
agent.role = "test_agent"
agent.execute_task.return_value = "bad result"
agent.crew = None
task = Task( task = Task(
description="Test task", description="Test task",
expected_output="Output", expected_output="Output",
@@ -86,23 +80,18 @@ def test_task_with_guardrail_retries():
) )
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent) task.execute_sync(agent=mock_agent)
assert task.retry_count == 2 assert task.retry_count == 2
assert "Task failed guardrail validation after 2 retries" in str(exc_info.value) assert "Task failed guardrail validation after 2 retries" in str(exc_info.value)
assert "Invalid format" in str(exc_info.value) assert "Invalid format" in str(exc_info.value)
def test_guardrail_error_in_context(): def test_guardrail_error_in_context(self, mock_agent):
"""Test that guardrail error is passed in context for retry.""" """Test that guardrail error is passed in context for retry."""
def guardrail(result: TaskOutput): def guardrail(result: TaskOutput):
return (False, "Expected JSON, got string") return (False, "Expected JSON, got string")
agent = Mock()
agent.role = "test_agent"
agent.crew = None
task = Task( task = Task(
description="Test task", description="Test task",
expected_output="Output", expected_output="Output",
@@ -112,7 +101,6 @@ def test_guardrail_error_in_context():
# Mock execute_task to succeed on second attempt # Mock execute_task to succeed on second attempt
first_call = True first_call = True
def execute_task(task, context, tools): def execute_task(task, context, tools):
nonlocal first_call nonlocal first_call
if first_call: if first_call:
@@ -120,52 +108,72 @@ def test_guardrail_error_in_context():
return "invalid" return "invalid"
return '{"valid": "json"}' return '{"valid": "json"}'
agent.execute_task.side_effect = execute_task mock_agent.execute_task.side_effect = execute_task
with pytest.raises(Exception) as exc_info: with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent) task.execute_sync(agent=mock_agent)
assert "Task failed guardrail validation" in str(exc_info.value) assert "Task failed guardrail validation" in str(exc_info.value)
assert "Expected JSON, got string" in str(exc_info.value) assert "Expected JSON, got string" in str(exc_info.value)
def test_guardrail_with_new_style_annotation(): def test_guardrail_with_new_style_annotation(self, mock_agent):
"""Test guardrail with new style tuple annotation.""" """Test guardrail with new style tuple annotation."""
def guardrail(result: TaskOutput) -> tuple[bool, str]: def guardrail(result: TaskOutput) -> tuple[bool, str]:
return (True, result.raw.upper()) return (True, result.raw.upper())
agent = Mock() mock_agent.execute_task.return_value = "test result"
agent.role = "test_agent"
agent.execute_task.return_value = "test result"
agent.crew = None
task = Task( task = Task(
description="Test task", description="Test task",
expected_output="Output", expected_output="Output",
guardrail=guardrail guardrail=guardrail
) )
result = task.execute_sync(agent=agent) result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput) assert isinstance(result, TaskOutput)
assert result.raw == "TEST RESULT" assert result.raw == "TEST RESULT"
def test_guardrail_with_optional_params(self, mock_agent):
def test_guardrail_with_optional_params():
"""Test guardrail with optional parameters.""" """Test guardrail with optional parameters."""
def guardrail(result: TaskOutput, optional_param: str = "default") -> tuple[bool, str]: def guardrail(result: TaskOutput, optional_param: str = "default") -> tuple[bool, str]:
return (True, f"{result.raw}-{optional_param}") return (True, f"{result.raw}-{optional_param}")
agent = Mock() mock_agent.execute_task.return_value = "test"
agent.role = "test_agent"
agent.execute_task.return_value = "test"
agent.crew = None
task = Task( task = Task(
description="Test task", description="Test task",
expected_output="Output", expected_output="Output",
guardrail=guardrail guardrail=guardrail
) )
result = task.execute_sync(agent=agent) result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput) assert isinstance(result, TaskOutput)
assert result.raw == "test-default" assert result.raw == "test-default"
def test_guardrail_with_invalid_optional_params(self, mock_agent):
"""Test guardrail with invalid optional parameters."""
def guardrail(result: TaskOutput, *, required_kwonly: str) -> tuple[bool, str]:
return (True, result.raw)
with pytest.raises(GuardrailValidationError) as exc_info:
Task(
description="Test task",
expected_output="Output",
guardrail=guardrail
)
assert "exactly one required positional parameter" in str(exc_info.value)
def test_guardrail_with_dict_return_type(self, mock_agent):
"""Test guardrail with dict return type."""
def guardrail(result: TaskOutput) -> tuple[bool, dict[str, Any]]:
return (True, {"processed": result.raw.upper()})
mock_agent.execute_task.return_value = "test"
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail
)
result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput)
assert result.raw == {"processed": "TEST"}