Supporting no-code Guardrail creation (#2636)

* feat: support defining a guardrail task with no code

* feat: add auto-discovery for Guardrail code execution mode

* feat: handle malformed or invalid response from CodeInterpreterTool

* feat: allow setting unsafe_mode from Guardrail task

* feat: renaming GuardrailTask to TaskGuardrail

* feat: ensure guardrail is callable while initializing Task

* feat: remove Docker availability check from TaskGuardrail

The CodeInterpreterTool already ensures compliance with this requirement.

* refactor: replace if/raise with assert

For this use case, `assert` is a more appropriate choice.

* test: remove useless or duplicated test

* fix: attempt to fix type-checker

* feat: support defining a task guardrail using YAML config

* refactor: simplify TaskGuardrail to use LLM for validation, no code generation

* docs: update TaskGuardrail doc strings

* refactor: drop task parameter from TaskGuardrail

This parameter was used to get the model from `task.agent`, which is a bit redundant since we can propagate the llm directly.
Commit 015e1a41b2 (parent 94b1a6cfb8)
Lucas Gomide, 2025-04-30 11:47:58 -03:00, committed by GitHub
18 changed files with 4935 additions and 1162 deletions
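In practice, the change reads roughly as follows: a guardrail can now be given as a plain-language string instead of a callable, and the Task wires it to a TaskGuardrail backed by the agent's LLM. A minimal sketch, assuming the usual crewai top-level imports; the agent and task fields are illustrative:

from crewai import Agent, Task

researcher = Agent(
    role="Researcher",
    goal="Summarize findings",
    backstory="Writes short, factual summaries.",
)

summary_task = Task(
    description="Summarize the report in one paragraph.",
    expected_output="A single paragraph of at most 100 words.",
    agent=researcher,
    # No code: the string below is turned into a TaskGuardrail at validation time.
    guardrail="The summary must be under 100 words and contain no bullet points.",
)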


@@ -483,6 +483,7 @@ class LLM(BaseLLM):
full_response += chunk_content
# Emit the chunk event
assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit(
    self,
    event=LLMStreamChunkEvent(chunk=chunk_content),
@@ -611,6 +612,7 @@ class LLM(BaseLLM):
return full_response
# Emit failed event and re-raise the exception
assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit(
    self,
    event=LLMCallFailedEvent(error=str(e)),
@@ -633,7 +635,7 @@ class LLM(BaseLLM):
current_tool_accumulator.function.arguments += (
    tool_call.function.arguments
)
assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit(
    self,
    event=LLMStreamChunkEvent(
@@ -806,6 +808,7 @@ class LLM(BaseLLM):
    function_name, lambda: None
)  # Ensure fn is always a callable
logging.error(f"Error executing function '{function_name}': {e}")
assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit(
    self,
    event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
@@ -843,6 +846,7 @@ class LLM(BaseLLM):
    LLMContextLengthExceededException: If input exceeds model's context limit
"""
# --- 1) Emit call started event
assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit(
    self,
    event=LLMCallStartedEvent(
@@ -891,6 +895,7 @@ class LLM(BaseLLM):
    # whether to summarize the content or abort based on the respect_context_window flag
    raise
except Exception as e:
    assert hasattr(crewai_event_bus, "emit")
    crewai_event_bus.emit(
        self,
        event=LLMCallFailedEvent(error=str(e)),
@@ -905,6 +910,7 @@ class LLM(BaseLLM):
    response (str): The response from the LLM call.
    call_type (str): The type of call, either "tool_call" or "llm_call".
"""
assert hasattr(crewai_event_bus, "emit")
crewai_event_bus.emit(
    self,
    event=LLMCallCompletedEvent(response=response, call_type=call_type),
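The repeated assert hasattr(crewai_event_bus, "emit") lines above guard each emit call; per the commit notes they replace an if/raise and were added while trying to satisfy the type checker. A small, self-contained sketch of the pattern, with a hypothetical bus class that is not crewai's real event bus:

class DemoBus:
    def emit(self, source, event):
        print(f"emit: {event!r}")


bus = DemoBus()


def notify(event):
    # Fail fast (AssertionError) if the bus object ever lacks an emit method,
    # instead of raising a less obvious AttributeError deeper in the call.
    assert hasattr(bus, "emit")
    bus.emit(None, event)


notify({"type": "llm_call_started"})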


@@ -246,6 +246,9 @@ def CrewBase(cls: T) -> T:
    callback_functions[callback]() for callback in callbacks
]
if guardrail := task_info.get("guardrail"):
    self.tasks_config[task_name]["guardrail"] = guardrail
# Include base class (qual)name in the wrapper class (qual)name.
WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")"
WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")"
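The hunk above copies a guardrail key from the task's YAML definition into tasks_config, which is what the "support defining a task guardrail using YAML config" commit refers to. A rough illustration of such an entry and where it ends up, with made-up task names and PyYAML standing in for the project's own config loading:

import yaml  # PyYAML, used here only to illustrate the shape of the config

tasks_yaml = """
summary_task:
  description: Summarize the report in one paragraph.
  expected_output: A single paragraph.
  agent: researcher
  guardrail: The summary must be under 100 words.
"""

tasks_config = yaml.safe_load(tasks_yaml)
# The decorator's task_info.get("guardrail") would pick up this value:
print(tasks_config["summary_task"]["guardrail"])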


@@ -140,9 +140,9 @@ class Task(BaseModel):
        default=None,
    )
    processed_by_agents: Set[str] = Field(default_factory=set)
    guardrail: Optional[Callable[[TaskOutput], Tuple[bool, Any]]] = Field(
    guardrail: Optional[Union[Callable[[TaskOutput], Tuple[bool, Any]], str]] = Field(
        default=None,
        description="Function to validate task output before proceeding to next task",
        description="Function or string description of a guardrail to validate task output before proceeding to next task",
    )
    max_retries: int = Field(
        default=3, description="Maximum number of retries when guardrail fails"
@@ -157,8 +157,12 @@
    @field_validator("guardrail")
    @classmethod
    def validate_guardrail_function(cls, v: Optional[Callable]) -> Optional[Callable]:
        """Validate that the guardrail function has the correct signature and behavior.
    def validate_guardrail_function(
        cls, v: Optional[str | Callable]
    ) -> Optional[str | Callable]:
        """
        If v is a callable, validate that the guardrail function has the correct signature and behavior.
        If v is a string, return it as is.

        While type hints provide static checking, this validator ensures runtime safety by:
        1. Verifying the function accepts exactly one parameter (the TaskOutput)
@@ -171,16 +175,16 @@
        - Clear error messages help users debug guardrail implementation issues

        Args:
            v: The guardrail function to validate
            v: The guardrail function to validate or a string describing the guardrail task

        Returns:
            The validated guardrail function
            The validated guardrail function or a string describing the guardrail task

        Raises:
            ValueError: If the function signature is invalid or return annotation
                doesn't match Tuple[bool, Any]
        """
        if v is not None:
        if v is not None and callable(v):
            sig = inspect.signature(v)
            positional_args = [
                param
@@ -211,6 +215,7 @@
            )
        return v

    _guardrail: Optional[Callable] = PrivateAttr(default=None)
    _original_description: Optional[str] = PrivateAttr(default=None)
    _original_expected_output: Optional[str] = PrivateAttr(default=None)
    _original_output_file: Optional[str] = PrivateAttr(default=None)
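For contrast with the new string form, a callable guardrail still has to match the signature the validator above enforces: exactly one positional parameter (the TaskOutput) and a Tuple[bool, Any] return. A minimal sketch; the word-count rule itself is made up:

from typing import Any, Tuple

from crewai.tasks.task_output import TaskOutput


def word_limit_guardrail(output: TaskOutput) -> Tuple[bool, Any]:
    # Returns (success, payload): the payload is the validated output on success,
    # or feedback explaining the failure otherwise.
    words = len(output.raw.split())
    if words <= 100:
        return True, output.raw
    return False, f"Output has {words} words; the limit is 100."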
@@ -231,6 +236,20 @@
        )
        return self

    @model_validator(mode="after")
    def ensure_guardrail_is_callable(self) -> "Task":
        if callable(self.guardrail):
            self._guardrail = self.guardrail
        elif isinstance(self.guardrail, str):
            from crewai.tasks.task_guardrail import TaskGuardrail

            assert self.agent is not None
            self._guardrail = TaskGuardrail(
                description=self.guardrail, llm=self.agent.llm
            )
        return self

    @field_validator("id", mode="before")
    @classmethod
    def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
@@ -407,10 +426,8 @@
                output_format=self._get_output_format(),
            )
            if self.guardrail:
                guardrail_result = GuardrailResult.from_tuple(
                    self.guardrail(task_output)
                )
            if self._guardrail:
                guardrail_result = self._process_guardrail(task_output)
                if not guardrail_result.success:
                    if self.retry_count >= self.max_retries:
                        raise Exception(
@@ -464,13 +481,46 @@
                )
            )
            self._save_file(content)
            crewai_event_bus.emit(self, TaskCompletedEvent(output=task_output, task=self))
            crewai_event_bus.emit(
                self, TaskCompletedEvent(output=task_output, task=self)
            )
            return task_output
        except Exception as e:
            self.end_time = datetime.datetime.now()
            crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
            raise e  # Re-raise the exception after emitting the event

    def _process_guardrail(self, task_output: TaskOutput) -> GuardrailResult:
        assert self._guardrail is not None

        from crewai.utilities.events import (
            TaskGuardrailCompletedEvent,
            TaskGuardrailStartedEvent,
        )
        from crewai.utilities.events.crewai_event_bus import crewai_event_bus

        result = self._guardrail(task_output)

        crewai_event_bus.emit(
            self,
            TaskGuardrailStartedEvent(
                guardrail=self._guardrail, retry_count=self.retry_count
            ),
        )

        guardrail_result = GuardrailResult.from_tuple(result)
        crewai_event_bus.emit(
            self,
            TaskGuardrailCompletedEvent(
                success=guardrail_result.success,
                result=guardrail_result.result,
                error=guardrail_result.error,
                retry_count=self.retry_count,
            ),
        )

        return guardrail_result

    def prompt(self) -> str:
        """Prompt the task.


@@ -0,0 +1,92 @@
from typing import Any, Optional, Tuple

from pydantic import BaseModel, Field

from crewai.agent import Agent, LiteAgentOutput
from crewai.llm import LLM
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput


class TaskGuardrailResult(BaseModel):
    valid: bool = Field(
        description="Whether the task output complies with the guardrail"
    )
    feedback: str | None = Field(
        description="Feedback about the task output if it is not valid",
        default=None,
    )


class TaskGuardrail:
    """Validates the output of another task using an LLM.

    This class is used to validate the output from a Task based on specified criteria.
    It uses an LLM to validate the output and provides feedback if the output is not valid.

    Args:
        description (str): The description of the validation criteria.
        llm (LLM): The language model to use for validation.
    """

    def __init__(
        self,
        description: str,
        llm: LLM,
    ):
        self.description = description
        self.llm: LLM = llm

    def _validate_output(self, task_output: TaskOutput) -> LiteAgentOutput:
        agent = Agent(
            role="Guardrail Agent",
            goal="Validate the output of the task",
            backstory="You are an expert at validating the output of a task, providing effective feedback when the output is not valid.",
            llm=self.llm,
        )

        query = f"""
        Ensure the following task result complies with the given guardrail.

        Task result:
        {task_output.raw}

        Guardrail:
        {self.description}

        Your task:
        - Confirm if the Task result complies with the guardrail.
        - If not, provide clear feedback explaining what is wrong (e.g., by how much it violates the rule, or what specific part fails).
        - Focus only on identifying issues — do not propose corrections.
        - If the Task result complies with the guardrail, state that it is valid.
        """

        result = agent.kickoff(query, response_format=TaskGuardrailResult)
        return result

    def __call__(self, task_output: TaskOutput) -> Tuple[bool, Any]:
        """Validates the output of a task based on specified criteria.

        Args:
            task_output (TaskOutput): The output to be validated.

        Returns:
            Tuple[bool, Any]: A tuple containing:
                - bool: True if validation passed, False otherwise
                - Any: The validation result or error message
        """
        try:
            result = self._validate_output(task_output)
            assert isinstance(
                result.pydantic, TaskGuardrailResult
            ), "The guardrail result is not a valid pydantic model"

            if result.pydantic.valid:
                return True, task_output.raw
            else:
                return False, result.pydantic.feedback
        except Exception as e:
            return False, f"Error while validating the task output: {str(e)}"


@@ -9,6 +9,10 @@ from .crew_events import (
    CrewTestCompletedEvent,
    CrewTestFailedEvent,
)
from .task_guardrail_events import (
    TaskGuardrailCompletedEvent,
    TaskGuardrailStartedEvent,
)
from .agent_events import (
    AgentExecutionStartedEvent,
    AgentExecutionCompletedEvent,


@@ -34,6 +34,10 @@ from .task_events import (
    TaskFailedEvent,
    TaskStartedEvent,
)
from .task_guardrail_events import (
    TaskGuardrailCompletedEvent,
    TaskGuardrailStartedEvent,
)
from .tool_usage_events import (
    ToolUsageErrorEvent,
    ToolUsageFinishedEvent,
@@ -68,4 +72,6 @@ EventTypes = Union[
    LLMCallCompletedEvent,
    LLMCallFailedEvent,
    LLMStreamChunkEvent,
    TaskGuardrailStartedEvent,
    TaskGuardrailCompletedEvent,
]
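Because the two new events are part of EventTypes, they can be observed like any other bus event. A small sketch, assuming the bus's usual on(...) decorator registration:

from crewai.utilities.events import (
    TaskGuardrailCompletedEvent,
    TaskGuardrailStartedEvent,
)
from crewai.utilities.events.crewai_event_bus import crewai_event_bus


@crewai_event_bus.on(TaskGuardrailStartedEvent)
def log_guardrail_started(source, event):
    # event.guardrail is the description string or the callable's source (see the event class below)
    print(f"Guardrail started (retry {event.retry_count}): {event.guardrail}")


@crewai_event_bus.on(TaskGuardrailCompletedEvent)
def log_guardrail_completed(source, event):
    print(f"Guardrail completed: success={event.success}, error={event.error}")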


@@ -0,0 +1,38 @@
from typing import Any, Callable, Optional, Union

from crewai.utilities.events.base_events import BaseEvent


class TaskGuardrailStartedEvent(BaseEvent):
    """Event emitted when a task guardrail starts.

    Attributes:
        guardrail: The guardrail callable or TaskGuardrail instance
        retry_count: The number of times the guardrail has been retried
    """

    type: str = "task_guardrail_started"
    guardrail: Union[str, Callable]
    retry_count: int

    def __init__(self, **data):
        from inspect import getsource

        from crewai.tasks.task_guardrail import TaskGuardrail

        super().__init__(**data)

        if isinstance(self.guardrail, TaskGuardrail):
            self.guardrail = self.guardrail.description.strip()
        elif isinstance(self.guardrail, Callable):
            self.guardrail = getsource(self.guardrail).strip()


class TaskGuardrailCompletedEvent(BaseEvent):
    """Event emitted when a task guardrail completes."""

    type: str = "task_guardrail_completed"
    success: bool
    result: Any
    error: Optional[str] = None
    retry_count: int
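A tiny sketch of the normalization done in __init__ above: a callable guardrail is stored as its source text, so listeners always receive a string. The word_limit function is made up; the constructor call mirrors the one in Task._process_guardrail:

from crewai.utilities.events import TaskGuardrailStartedEvent


def word_limit(output):
    return len(str(output).split()) <= 100, output


event = TaskGuardrailStartedEvent(guardrail=word_limit, retry_count=0)
print(event.guardrail)  # the source code of word_limit, captured via inspect.getsource
# A TaskGuardrail instance would instead be replaced by its stripped description string.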