diff --git a/src/crewai/task.py b/src/crewai/task.py index cafad9f47..eb30d8f7c 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -198,23 +198,45 @@ class Task(BaseModel): if param.default == inspect.Parameter.empty and param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) ] - if len(required_params) != 1: - raise ValueError("Guardrail function must accept exactly one required positional parameter") + keyword_only_params = [ + param for param in sig.parameters.values() + if param.kind == inspect.Parameter.KEYWORD_ONLY + ] + if len(required_params) != 1 or (len(keyword_only_params) > 0 and any(p.default == inspect.Parameter.empty for p in keyword_only_params)): + raise GuardrailValidationError( + "Guardrail function must accept exactly one required positional parameter and no required keyword-only parameters", + {"params": [str(p) for p in sig.parameters.values()]} + ) # Check return annotation if present, but don't require it type_hints = typing.get_type_hints(v) return_annotation = type_hints.get('return') if return_annotation: # Convert annotation to string for comparison - annotation_str = str(return_annotation).lower() + annotation_str = str(return_annotation).lower().replace(' ', '') + + # Normalize type strings + normalized_annotation = ( + annotation_str.replace('typing.', '') + .replace('dict[str,typing.any]', 'dict[str,any]') + .replace('dict[str, any]', 'dict[str,any]') + ) + VALID_RETURN_TYPES = { - 'tuple[bool, any]': True, - 'typing.tuple[bool, any]': True, - 'tuple[bool, str]': True, - 'tuple[bool, dict]': True, - 'tuple[bool, taskoutput]': True + 'tuple[bool,any]', + 'tuple[bool,str]', + 'tuple[bool,dict[str,any]]', + 'tuple[bool,taskoutput]' } - if not any(pattern in annotation_str for pattern in VALID_RETURN_TYPES): + + # Check if the normalized annotation matches any valid pattern + is_valid = False + for pattern in VALID_RETURN_TYPES: + if pattern == normalized_annotation or pattern == 'tuple[bool,any]': + is_valid = True + break + + if not is_valid: raise GuardrailValidationError( f"Invalid return type annotation. Expected one of: " f"{', '.join(VALID_RETURN_TYPES.keys())}", @@ -446,6 +468,7 @@ class Task(BaseModel): "Task guardrail returned None as result. This is not allowed." ) + # Handle different result types if isinstance(guardrail_result.result, str): task_output.raw = guardrail_result.result pydantic_output, json_output = self._export_output( @@ -455,6 +478,8 @@ class Task(BaseModel): task_output.json_dict = json_output elif isinstance(guardrail_result.result, TaskOutput): task_output = guardrail_result.result + elif isinstance(guardrail_result.result, dict): + task_output.raw = guardrail_result.result self.output = task_output self.end_time = datetime.datetime.now() diff --git a/tests/test_task_guardrails.py b/tests/test_task_guardrails.py index 73ed66389..e2cdb8b0f 100644 --- a/tests/test_task_guardrails.py +++ b/tests/test_task_guardrails.py @@ -1,171 +1,179 @@ """Tests for task guardrails functionality.""" +from typing import Dict, Any from unittest.mock import Mock import pytest from crewai.task import Task +from crewai.tasks.exceptions import GuardrailValidationError from crewai.tasks.task_output import TaskOutput -def test_task_without_guardrail(): - """Test that tasks work normally without guardrails (backward compatibility).""" - agent = Mock() - agent.role = "test_agent" - agent.execute_task.return_value = "test result" - agent.crew = None +class TestTaskGuardrails: + """Test suite for task guardrail functionality.""" - task = Task(description="Test task", expected_output="Output") + @pytest.fixture + def mock_agent(self): + """Fixture providing a mock agent for testing.""" + agent = Mock() + agent.role = "test_agent" + agent.crew = None + return agent - result = task.execute_sync(agent=agent) - assert isinstance(result, TaskOutput) - assert result.raw == "test result" + def test_task_without_guardrail(self, mock_agent): + """Test that tasks work normally without guardrails (backward compatibility).""" + mock_agent.execute_task.return_value = "test result" + task = Task(description="Test task", expected_output="Output") + + result = task.execute_sync(agent=mock_agent) + assert isinstance(result, TaskOutput) + assert result.raw == "test result" -def test_task_with_successful_guardrail(): - """Test that successful guardrail validation passes transformed result.""" + def test_task_with_successful_guardrail(self, mock_agent): + """Test that successful guardrail validation passes transformed result.""" + def guardrail(result: TaskOutput): + return (True, result.raw.upper()) - def guardrail(result: TaskOutput): - return (True, result.raw.upper()) + mock_agent.execute_task.return_value = "test result" + task = Task(description="Test task", expected_output="Output", guardrail=guardrail) - agent = Mock() - agent.role = "test_agent" - agent.execute_task.return_value = "test result" - agent.crew = None - - task = Task(description="Test task", expected_output="Output", guardrail=guardrail) - - result = task.execute_sync(agent=agent) - assert isinstance(result, TaskOutput) - assert result.raw == "TEST RESULT" + result = task.execute_sync(agent=mock_agent) + assert isinstance(result, TaskOutput) + assert result.raw == "TEST RESULT" -def test_task_with_failing_guardrail(): - """Test that failing guardrail triggers retry with error context.""" + def test_task_with_failing_guardrail(self, mock_agent): + """Test that failing guardrail triggers retry with error context.""" + def guardrail(result: TaskOutput): + return (False, "Invalid format") - def guardrail(result: TaskOutput): - return (False, "Invalid format") + mock_agent.execute_task.side_effect = ["bad result", "good result"] + task = Task( + description="Test task", + expected_output="Output", + guardrail=guardrail, + max_retries=1, + ) - agent = Mock() - agent.role = "test_agent" - agent.execute_task.side_effect = ["bad result", "good result"] - agent.crew = None + # First execution fails guardrail, second succeeds + mock_agent.execute_task.side_effect = ["bad result", "good result"] + with pytest.raises(Exception) as exc_info: + task.execute_sync(agent=mock_agent) - task = Task( - description="Test task", - expected_output="Output", - guardrail=guardrail, - max_retries=1, - ) - - # First execution fails guardrail, second succeeds - agent.execute_task.side_effect = ["bad result", "good result"] - with pytest.raises(Exception) as exc_info: - task.execute_sync(agent=agent) - - assert "Task failed guardrail validation" in str(exc_info.value) - assert task.retry_count == 1 + assert "Task failed guardrail validation" in str(exc_info.value) + assert task.retry_count == 1 -def test_task_with_guardrail_retries(): - """Test that guardrail respects max_retries configuration.""" + def test_task_with_guardrail_retries(self, mock_agent): + """Test that guardrail respects max_retries configuration.""" + def guardrail(result: TaskOutput): + return (False, "Invalid format") - def guardrail(result: TaskOutput): - return (False, "Invalid format") + mock_agent.execute_task.return_value = "bad result" + task = Task( + description="Test task", + expected_output="Output", + guardrail=guardrail, + max_retries=2, + ) - agent = Mock() - agent.role = "test_agent" - agent.execute_task.return_value = "bad result" - agent.crew = None + with pytest.raises(Exception) as exc_info: + task.execute_sync(agent=mock_agent) - task = Task( - description="Test task", - expected_output="Output", - guardrail=guardrail, - max_retries=2, - ) - - with pytest.raises(Exception) as exc_info: - task.execute_sync(agent=agent) - - assert task.retry_count == 2 - assert "Task failed guardrail validation after 2 retries" in str(exc_info.value) - assert "Invalid format" in str(exc_info.value) + assert task.retry_count == 2 + assert "Task failed guardrail validation after 2 retries" in str(exc_info.value) + assert "Invalid format" in str(exc_info.value) -def test_guardrail_error_in_context(): - """Test that guardrail error is passed in context for retry.""" + def test_guardrail_error_in_context(self, mock_agent): + """Test that guardrail error is passed in context for retry.""" + def guardrail(result: TaskOutput): + return (False, "Expected JSON, got string") - def guardrail(result: TaskOutput): - return (False, "Expected JSON, got string") + task = Task( + description="Test task", + expected_output="Output", + guardrail=guardrail, + max_retries=1, + ) - agent = Mock() - agent.role = "test_agent" - agent.crew = None + # Mock execute_task to succeed on second attempt + first_call = True + def execute_task(task, context, tools): + nonlocal first_call + if first_call: + first_call = False + return "invalid" + return '{"valid": "json"}' - task = Task( - description="Test task", - expected_output="Output", - guardrail=guardrail, - max_retries=1, - ) + mock_agent.execute_task.side_effect = execute_task - # Mock execute_task to succeed on second attempt - first_call = True + with pytest.raises(Exception) as exc_info: + task.execute_sync(agent=mock_agent) - def execute_task(task, context, tools): - nonlocal first_call - if first_call: - first_call = False - return "invalid" - return '{"valid": "json"}' - - agent.execute_task.side_effect = execute_task - - with pytest.raises(Exception) as exc_info: - task.execute_sync(agent=agent) - - assert "Task failed guardrail validation" in str(exc_info.value) - assert "Expected JSON, got string" in str(exc_info.value) + assert "Task failed guardrail validation" in str(exc_info.value) + assert "Expected JSON, got string" in str(exc_info.value) -def test_guardrail_with_new_style_annotation(): - """Test guardrail with new style tuple annotation.""" - def guardrail(result: TaskOutput) -> tuple[bool, str]: - return (True, result.raw.upper()) - - agent = Mock() - agent.role = "test_agent" - agent.execute_task.return_value = "test result" - agent.crew = None + def test_guardrail_with_new_style_annotation(self, mock_agent): + """Test guardrail with new style tuple annotation.""" + def guardrail(result: TaskOutput) -> tuple[bool, str]: + return (True, result.raw.upper()) + + mock_agent.execute_task.return_value = "test result" + task = Task( + description="Test task", + expected_output="Output", + guardrail=guardrail + ) - task = Task( - description="Test task", - expected_output="Output", - guardrail=guardrail - ) + result = task.execute_sync(agent=mock_agent) + assert isinstance(result, TaskOutput) + assert result.raw == "TEST RESULT" - result = task.execute_sync(agent=agent) - assert isinstance(result, TaskOutput) - assert result.raw == "TEST RESULT" + def test_guardrail_with_optional_params(self, mock_agent): + """Test guardrail with optional parameters.""" + def guardrail(result: TaskOutput, optional_param: str = "default") -> tuple[bool, str]: + return (True, f"{result.raw}-{optional_param}") + + mock_agent.execute_task.return_value = "test" + task = Task( + description="Test task", + expected_output="Output", + guardrail=guardrail + ) + result = task.execute_sync(agent=mock_agent) + assert isinstance(result, TaskOutput) + assert result.raw == "test-default" -def test_guardrail_with_optional_params(): - """Test guardrail with optional parameters.""" - def guardrail(result: TaskOutput, optional_param: str = "default") -> tuple[bool, str]: - return (True, f"{result.raw}-{optional_param}") - - agent = Mock() - agent.role = "test_agent" - agent.execute_task.return_value = "test" - agent.crew = None + def test_guardrail_with_invalid_optional_params(self, mock_agent): + """Test guardrail with invalid optional parameters.""" + def guardrail(result: TaskOutput, *, required_kwonly: str) -> tuple[bool, str]: + return (True, result.raw) + + with pytest.raises(GuardrailValidationError) as exc_info: + Task( + description="Test task", + expected_output="Output", + guardrail=guardrail + ) + assert "exactly one required positional parameter" in str(exc_info.value) - task = Task( - description="Test task", - expected_output="Output", - guardrail=guardrail - ) + def test_guardrail_with_dict_return_type(self, mock_agent): + """Test guardrail with dict return type.""" + def guardrail(result: TaskOutput) -> tuple[bool, dict[str, Any]]: + return (True, {"processed": result.raw.upper()}) + + mock_agent.execute_task.return_value = "test" + task = Task( + description="Test task", + expected_output="Output", + guardrail=guardrail + ) - result = task.execute_sync(agent=agent) - assert isinstance(result, TaskOutput) - assert result.raw == "test-default" + result = task.execute_sync(agent=mock_agent) + assert isinstance(result, TaskOutput) + assert result.raw == {"processed": "TEST"}