refactor: Update guardrail functions to handle TaskOutput objects

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2024-12-12 04:22:33 +00:00
parent 88fd456b1f
commit 2ef341e992
2 changed files with 23 additions and 22 deletions

View File

@@ -112,7 +112,7 @@ class Task(BaseModel):
default=None,
)
processed_by_agents: Set[str] = Field(default_factory=set)
guardrail: Optional[Callable[[Any], Tuple[bool, Any]]] = Field(
guardrail: Optional[Callable[[TaskOutput], Tuple[bool, Any]]] = Field(
default=None,
description="Function to validate task output before proceeding to next task"
)
@@ -283,9 +283,20 @@ class Task(BaseModel):
tools=tools,
)
# Add guardrail validation
pydantic_output, json_output = self._export_output(result)
task_output = TaskOutput(
name=self.name,
description=self.description,
expected_output=self.expected_output,
raw=result,
pydantic=pydantic_output,
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
)
if self.guardrail:
guardrail_result = GuardrailResult.from_tuple(self.guardrail(result))
guardrail_result = GuardrailResult.from_tuple(self.guardrail(task_output))
if not guardrail_result.success:
if self.retry_count >= self.max_retries:
raise Exception(
@@ -297,25 +308,15 @@ class Task(BaseModel):
context = f"Previous attempt failed validation: {guardrail_result.error}\nPlease try again."
return self._execute_core(agent, context, tools)
# Ensure result is not None before assignment
if guardrail_result.result is None:
raise Exception(
"Task guardrail returned None as result. This is not allowed."
)
result = guardrail_result.result
task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output(guardrail_result.result)
task_output.pydantic = pydantic_output
task_output.json_dict = json_output
pydantic_output, json_output = self._export_output(result)
task_output = TaskOutput(
name=self.name,
description=self.description,
expected_output=self.expected_output,
raw=result,
pydantic=pydantic_output,
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
)
self.output = task_output
self._set_end_execution_time(start_time)

View File

@@ -26,8 +26,8 @@ def test_task_without_guardrail():
def test_task_with_successful_guardrail():
"""Test that successful guardrail validation passes transformed result."""
def guardrail(result):
return (True, result.upper())
def guardrail(result: TaskOutput):
return (True, result.raw.upper())
agent = Mock()
agent.role = "test_agent"
@@ -47,7 +47,7 @@ def test_task_with_successful_guardrail():
def test_task_with_failing_guardrail():
"""Test that failing guardrail triggers retry with error context."""
def guardrail(result):
def guardrail(result: TaskOutput):
return (False, "Invalid format")
agent = Mock()
@@ -76,7 +76,7 @@ def test_task_with_failing_guardrail():
def test_task_with_guardrail_retries():
"""Test that guardrail respects max_retries configuration."""
def guardrail(result):
def guardrail(result: TaskOutput):
return (False, "Invalid format")
agent = Mock()
@@ -101,7 +101,7 @@ def test_task_with_guardrail_retries():
def test_guardrail_error_in_context():
"""Test that guardrail error is passed in context for retry."""
def guardrail(result):
def guardrail(result: TaskOutput):
return (False, "Expected JSON, got string")
agent = Mock()