diff --git a/src/crewai/task.py b/src/crewai/task.py index 8c5a95171..3f543888c 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -112,7 +112,7 @@ class Task(BaseModel): default=None, ) processed_by_agents: Set[str] = Field(default_factory=set) - guardrail: Optional[Callable[[Any], Tuple[bool, Any]]] = Field( + guardrail: Optional[Callable[[TaskOutput], Tuple[bool, Any]]] = Field( default=None, description="Function to validate task output before proceeding to next task" ) @@ -283,9 +283,20 @@ class Task(BaseModel): tools=tools, ) - # Add guardrail validation + pydantic_output, json_output = self._export_output(result) + task_output = TaskOutput( + name=self.name, + description=self.description, + expected_output=self.expected_output, + raw=result, + pydantic=pydantic_output, + json_dict=json_output, + agent=agent.role, + output_format=self._get_output_format(), + ) + if self.guardrail: - guardrail_result = GuardrailResult.from_tuple(self.guardrail(result)) + guardrail_result = GuardrailResult.from_tuple(self.guardrail(task_output)) if not guardrail_result.success: if self.retry_count >= self.max_retries: raise Exception( @@ -297,25 +308,15 @@ class Task(BaseModel): context = f"Previous attempt failed validation: {guardrail_result.error}\nPlease try again." return self._execute_core(agent, context, tools) - # Ensure result is not None before assignment if guardrail_result.result is None: raise Exception( "Task guardrail returned None as result. This is not allowed." ) - result = guardrail_result.result + task_output.raw = guardrail_result.result + pydantic_output, json_output = self._export_output(guardrail_result.result) + task_output.pydantic = pydantic_output + task_output.json_dict = json_output - pydantic_output, json_output = self._export_output(result) - - task_output = TaskOutput( - name=self.name, - description=self.description, - expected_output=self.expected_output, - raw=result, - pydantic=pydantic_output, - json_dict=json_output, - agent=agent.role, - output_format=self._get_output_format(), - ) self.output = task_output self._set_end_execution_time(start_time) diff --git a/tests/test_task_guardrails.py b/tests/test_task_guardrails.py index 3d1f729c5..338b771b8 100644 --- a/tests/test_task_guardrails.py +++ b/tests/test_task_guardrails.py @@ -26,8 +26,8 @@ def test_task_without_guardrail(): def test_task_with_successful_guardrail(): """Test that successful guardrail validation passes transformed result.""" - def guardrail(result): - return (True, result.upper()) + def guardrail(result: TaskOutput): + return (True, result.raw.upper()) agent = Mock() agent.role = "test_agent" @@ -47,7 +47,7 @@ def test_task_with_successful_guardrail(): def test_task_with_failing_guardrail(): """Test that failing guardrail triggers retry with error context.""" - def guardrail(result): + def guardrail(result: TaskOutput): return (False, "Invalid format") agent = Mock() @@ -76,7 +76,7 @@ def test_task_with_failing_guardrail(): def test_task_with_guardrail_retries(): """Test that guardrail respects max_retries configuration.""" - def guardrail(result): + def guardrail(result: TaskOutput): return (False, "Invalid format") agent = Mock() @@ -101,7 +101,7 @@ def test_task_with_guardrail_retries(): def test_guardrail_error_in_context(): """Test that guardrail error is passed in context for retry.""" - def guardrail(result): + def guardrail(result: TaskOutput): return (False, "Expected JSON, got string") agent = Mock()