mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-15 19:18:30 +00:00
feat: support to define a guardrail task no-code
This commit is contained in:
@@ -140,7 +140,7 @@ class Task(BaseModel):
|
||||
default=None,
|
||||
)
|
||||
processed_by_agents: Set[str] = Field(default_factory=set)
|
||||
guardrail: Optional[Callable[[TaskOutput], Tuple[bool, Any]]] = Field(
|
||||
guardrail: Optional[Union[Callable[[TaskOutput], Tuple[bool, Any]], str]] = Field(
|
||||
default=None,
|
||||
description="Function to validate task output before proceeding to next task",
|
||||
)
|
||||
@@ -157,8 +157,12 @@ class Task(BaseModel):
|
||||
|
||||
@field_validator("guardrail")
|
||||
@classmethod
|
||||
def validate_guardrail_function(cls, v: Optional[Callable]) -> Optional[Callable]:
|
||||
"""Validate that the guardrail function has the correct signature and behavior.
|
||||
def validate_guardrail_function(
|
||||
cls, v: Optional[str | Callable]
|
||||
) -> Optional[str | Callable]:
|
||||
"""
|
||||
If v is a callable, validate that the guardrail function has the correct signature and behavior.
|
||||
If v is a string, return it as is.
|
||||
|
||||
While type hints provide static checking, this validator ensures runtime safety by:
|
||||
1. Verifying the function accepts exactly one parameter (the TaskOutput)
|
||||
@@ -171,16 +175,16 @@ class Task(BaseModel):
|
||||
- Clear error messages help users debug guardrail implementation issues
|
||||
|
||||
Args:
|
||||
v: The guardrail function to validate
|
||||
v: The guardrail function to validate or a string describing the guardrail task
|
||||
|
||||
Returns:
|
||||
The validated guardrail function
|
||||
The validated guardrail function or a string describing the guardrail task
|
||||
|
||||
Raises:
|
||||
ValueError: If the function signature is invalid or return annotation
|
||||
doesn't match Tuple[bool, Any]
|
||||
"""
|
||||
if v is not None:
|
||||
if v is not None and callable(v):
|
||||
sig = inspect.signature(v)
|
||||
positional_args = [
|
||||
param
|
||||
@@ -408,9 +412,7 @@ class Task(BaseModel):
|
||||
)
|
||||
|
||||
if self.guardrail:
|
||||
guardrail_result = GuardrailResult.from_tuple(
|
||||
self.guardrail(task_output)
|
||||
)
|
||||
guardrail_result = self._process_guardrail(task_output)
|
||||
if not guardrail_result.success:
|
||||
if self.retry_count >= self.max_retries:
|
||||
raise Exception(
|
||||
@@ -464,13 +466,52 @@ class Task(BaseModel):
|
||||
)
|
||||
)
|
||||
self._save_file(content)
|
||||
crewai_event_bus.emit(self, TaskCompletedEvent(output=task_output, task=self))
|
||||
crewai_event_bus.emit(
|
||||
self, TaskCompletedEvent(output=task_output, task=self)
|
||||
)
|
||||
return task_output
|
||||
except Exception as e:
|
||||
self.end_time = datetime.datetime.now()
|
||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
|
||||
raise e # Re-raise the exception after emitting the event
|
||||
|
||||
def _process_guardrail(self, task_output: TaskOutput) -> GuardrailResult:
|
||||
if self.guardrail is None:
|
||||
raise ValueError("Guardrail is not set")
|
||||
|
||||
from crewai.utilities.events import (
|
||||
GuardrailTaskCompletedEvent,
|
||||
GuardrailTaskStartedEvent,
|
||||
)
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
GuardrailTaskStartedEvent(
|
||||
guardrail=self.guardrail, retry_count=self.retry_count
|
||||
),
|
||||
)
|
||||
|
||||
if isinstance(self.guardrail, str):
|
||||
from crewai.tasks.guardrail_task import GuardrailTask
|
||||
|
||||
result = GuardrailTask(description=self.guardrail, task=self)(task_output)
|
||||
else:
|
||||
result = self.guardrail(task_output)
|
||||
|
||||
guardrail_result = GuardrailResult.from_tuple(result)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
GuardrailTaskCompletedEvent(
|
||||
success=guardrail_result.success,
|
||||
result=guardrail_result.result,
|
||||
error=guardrail_result.error,
|
||||
retry_count=self.retry_count,
|
||||
),
|
||||
)
|
||||
return guardrail_result
|
||||
|
||||
def prompt(self) -> str:
|
||||
"""Prompt the task.
|
||||
|
||||
|
||||
154
src/crewai/tasks/guardrail_task.py
Normal file
154
src/crewai/tasks/guardrail_task.py
Normal file
@@ -0,0 +1,154 @@
|
||||
from typing import Any, Tuple
|
||||
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
|
||||
class GuardrailTask:
|
||||
"""A task that validates the output of another task using generated Python code.
|
||||
|
||||
This class generates and executes Python code to validate task outputs based on
|
||||
specified criteria. It uses an LLM to generate the validation code and provides
|
||||
safety guardrails for code execution.
|
||||
|
||||
Args:
|
||||
description (str): The description of the validation criteria.
|
||||
task (Task, optional): The task whose output needs validation.
|
||||
llm (LLM, optional): The language model to use for code generation.
|
||||
additional_instructions (str, optional): Additional instructions for the guardrail task.
|
||||
|
||||
Raises:
|
||||
ValueError: If no valid LLM is provided.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
description: str,
|
||||
task: Task | None = None,
|
||||
llm: LLM | None = None,
|
||||
unsafe_mode: bool = False,
|
||||
additional_instructions: str = "",
|
||||
):
|
||||
self.description = description
|
||||
self.unsafe_mode: bool = unsafe_mode
|
||||
|
||||
fallback_llm: LLM | None = (
|
||||
task.agent.llm
|
||||
if task is not None
|
||||
and hasattr(task, "agent")
|
||||
and task.agent is not None
|
||||
and hasattr(task.agent, "llm")
|
||||
else None
|
||||
)
|
||||
self.llm: LLM | None = llm or fallback_llm
|
||||
|
||||
self.additional_instructions = additional_instructions
|
||||
|
||||
@property
|
||||
def system_instructions(self) -> str:
|
||||
"""System instructions for the LLM code generation.
|
||||
|
||||
Returns:
|
||||
str: Complete system instructions including security constraints.
|
||||
"""
|
||||
security_instructions = (
|
||||
"- DO NOT wrap the output in markdown or use triple backticks. Return only raw Python code."
|
||||
"- DO NOT use `exec`, `eval`, `compile`, `open`, `os`, `subprocess`, `socket`, `shutil`, or any other system-level modules.\n"
|
||||
"- Your code must not perform any file I/O, shell access, or dynamic code execution."
|
||||
)
|
||||
return (
|
||||
"You are a expert Python developer"
|
||||
"You **must strictly** follow the task description, use the provided raw output as the input in your code. "
|
||||
"Your code must:\n"
|
||||
"- Return results with: print((True, data)) on success, or print((False, 'very detailed error message')) on failure. Make sure the final output is beign assined to 'result' variable.\n"
|
||||
"- Use the literal string of the task output (already included in your input) if needed.\n"
|
||||
"- Generate the code **following strictly** the task description.\n"
|
||||
"- Be valid Python 3 — executable as-is.\n"
|
||||
f"{security_instructions}\n"
|
||||
"Additional instructions (do not override the previous instructions):\n"
|
||||
f"{self.additional_instructions}"
|
||||
)
|
||||
|
||||
def user_instructions(self, task_output: TaskOutput) -> str:
|
||||
"""Generates user instructions for the LLM code generation.
|
||||
|
||||
Args:
|
||||
task_output (TaskOutput): The output to be validated.
|
||||
|
||||
Returns:
|
||||
str: Instructions for generating validation code.
|
||||
"""
|
||||
return (
|
||||
"Based on the task description below, generate Python 3 code that validates the task output. \n"
|
||||
"Task description:\n"
|
||||
f"{self.description}\n"
|
||||
"Here is the raw output from the task: \n"
|
||||
f"'{task_output.raw}' \n"
|
||||
"Use this exact string literal inside your generated code (do not reference variables like task_output.raw)."
|
||||
"Now generate Python code that follows the instructions above."
|
||||
)
|
||||
|
||||
def generate_code(self, task_output: TaskOutput) -> str:
|
||||
"""Generates Python code for validating the task output.
|
||||
|
||||
Args:
|
||||
task_output (TaskOutput): The output to be validated.
|
||||
|
||||
Returns:
|
||||
str: Generated Python code for validation.
|
||||
"""
|
||||
if self.llm is None:
|
||||
raise ValueError("Provide a valid LLM to the GuardrailTask")
|
||||
|
||||
response = self.llm.call(
|
||||
messages=[
|
||||
{
|
||||
"role": "system",
|
||||
"content": self.system_instructions,
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": self.user_instructions(task_output=task_output),
|
||||
},
|
||||
]
|
||||
)
|
||||
|
||||
printer = Printer()
|
||||
printer.print(
|
||||
content=f"The following code was generated for the guardrail task:\n{response}\n",
|
||||
color="cyan",
|
||||
)
|
||||
return response
|
||||
|
||||
def __call__(self, task_output: TaskOutput) -> Tuple[bool, Any]:
|
||||
"""Executes the validation code on the task output.
|
||||
|
||||
Args:
|
||||
task_output (TaskOutput): The output to be validated.
|
||||
|
||||
Returns:
|
||||
Tuple[bool, Any]: A tuple containing:
|
||||
- bool: True if validation passed, False otherwise
|
||||
- Any: The validation result or error message
|
||||
"""
|
||||
import ast
|
||||
|
||||
from crewai_tools import CodeInterpreterTool
|
||||
|
||||
code = self.generate_code(task_output)
|
||||
result = CodeInterpreterTool(code=code, unsafe_mode=self.unsafe_mode).run()
|
||||
|
||||
error_messages = [
|
||||
"Something went wrong while running the code",
|
||||
"No result variable found", # when running in unsafe mode, the final output should be stored in the result variable
|
||||
]
|
||||
|
||||
if any(msg in result for msg in error_messages):
|
||||
return False, result
|
||||
|
||||
if isinstance(result, str):
|
||||
result = ast.literal_eval(result)
|
||||
|
||||
return result
|
||||
@@ -9,6 +9,10 @@ from .crew_events import (
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
)
|
||||
from .guardrail_task_events import (
|
||||
GuardrailTaskCompletedEvent,
|
||||
GuardrailTaskStartedEvent,
|
||||
)
|
||||
from .agent_events import (
|
||||
AgentExecutionStartedEvent,
|
||||
AgentExecutionCompletedEvent,
|
||||
|
||||
@@ -23,6 +23,10 @@ from .flow_events import (
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from .guardrail_task_events import (
|
||||
GuardrailTaskCompletedEvent,
|
||||
GuardrailTaskStartedEvent,
|
||||
)
|
||||
from .llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
@@ -68,4 +72,6 @@ EventTypes = Union[
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMStreamChunkEvent,
|
||||
GuardrailTaskStartedEvent,
|
||||
GuardrailTaskCompletedEvent,
|
||||
]
|
||||
|
||||
28
src/crewai/utilities/events/guardrail_task_events.py
Normal file
28
src/crewai/utilities/events/guardrail_task_events.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.utilities.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class GuardrailTaskStartedEvent(BaseEvent):
|
||||
"""Event emitted when a guardrail task starts
|
||||
|
||||
Attributes:
|
||||
messages: Content can be either a string or a list of dictionaries that support
|
||||
multimodal content (text, images, etc.)
|
||||
"""
|
||||
|
||||
type: str = "guardrail_task_started"
|
||||
guardrail: Union[str, Callable]
|
||||
retry_count: int
|
||||
|
||||
|
||||
class GuardrailTaskCompletedEvent(BaseEvent):
|
||||
"""Event emitted when a guardrail task completes"""
|
||||
|
||||
type: str = "guardrail_task_completed"
|
||||
success: bool
|
||||
result: Any
|
||||
error: Optional[str] = None
|
||||
retry_count: int
|
||||
Reference in New Issue
Block a user