mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-20 21:38:14 +00:00
Supporting no-code Guardrail creation (#2636)
* feat: support defining a guardrail task with no code * feat: add auto-discovery for Guardrail code execution mode * feat: handle malformed or invalid response from CodeInterpreterTool * feat: allow setting unsafe_mode from Guardrail task * feat: rename GuardrailTask to TaskGuardrail * feat: ensure guardrail is callable while initializing Task * feat: remove Docker availability check from TaskGuardrail The CodeInterpreterTool already ensures compliance with this requirement. * refactor: replace if/raise with assert For this use case `assert` is the more appropriate choice * test: remove useless or duplicated test * fix: attempt to fix type-checker * feat: support defining a task guardrail using YAML config * refactor: simplify TaskGuardrail to use LLM for validation, no code generation * docs: update TaskGuardrail docstrings * refactor: drop task parameter from TaskGuardrail This parameter was used to get the model from the `task.agent`, which is quite a bit redundant since we could propagate the llm directly
This commit is contained in:
@@ -483,6 +483,7 @@ class LLM(BaseLLM):
|
||||
full_response += chunk_content
|
||||
|
||||
# Emit the chunk event
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMStreamChunkEvent(chunk=chunk_content),
|
||||
@@ -611,6 +612,7 @@ class LLM(BaseLLM):
|
||||
return full_response
|
||||
|
||||
# Emit failed event and re-raise the exception
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(error=str(e)),
|
||||
@@ -633,7 +635,7 @@ class LLM(BaseLLM):
|
||||
current_tool_accumulator.function.arguments += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMStreamChunkEvent(
|
||||
@@ -806,6 +808,7 @@ class LLM(BaseLLM):
|
||||
function_name, lambda: None
|
||||
) # Ensure fn is always a callable
|
||||
logging.error(f"Error executing function '{function_name}': {e}")
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(error=f"Tool execution error: {str(e)}"),
|
||||
@@ -843,6 +846,7 @@ class LLM(BaseLLM):
|
||||
LLMContextLengthExceededException: If input exceeds model's context limit
|
||||
"""
|
||||
# --- 1) Emit call started event
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(
|
||||
@@ -891,6 +895,7 @@ class LLM(BaseLLM):
|
||||
# whether to summarize the content or abort based on the respect_context_window flag
|
||||
raise
|
||||
except Exception as e:
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(error=str(e)),
|
||||
@@ -905,6 +910,7 @@ class LLM(BaseLLM):
|
||||
response (str): The response from the LLM call.
|
||||
call_type (str): The type of call, either "tool_call" or "llm_call".
|
||||
"""
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(response=response, call_type=call_type),
|
||||
|
||||
@@ -246,6 +246,9 @@ def CrewBase(cls: T) -> T:
|
||||
callback_functions[callback]() for callback in callbacks
|
||||
]
|
||||
|
||||
if guardrail := task_info.get("guardrail"):
|
||||
self.tasks_config[task_name]["guardrail"] = guardrail
|
||||
|
||||
# Include base class (qual)name in the wrapper class (qual)name.
|
||||
WrappedClass.__name__ = CrewBase.__name__ + "(" + cls.__name__ + ")"
|
||||
WrappedClass.__qualname__ = CrewBase.__qualname__ + "(" + cls.__name__ + ")"
|
||||
|
||||
@@ -140,9 +140,9 @@ class Task(BaseModel):
|
||||
default=None,
|
||||
)
|
||||
processed_by_agents: Set[str] = Field(default_factory=set)
|
||||
guardrail: Optional[Callable[[TaskOutput], Tuple[bool, Any]]] = Field(
|
||||
guardrail: Optional[Union[Callable[[TaskOutput], Tuple[bool, Any]], str]] = Field(
|
||||
default=None,
|
||||
description="Function to validate task output before proceeding to next task",
|
||||
description="Function or string description of a guardrail to validate task output before proceeding to next task",
|
||||
)
|
||||
max_retries: int = Field(
|
||||
default=3, description="Maximum number of retries when guardrail fails"
|
||||
@@ -157,8 +157,12 @@ class Task(BaseModel):
|
||||
|
||||
@field_validator("guardrail")
|
||||
@classmethod
|
||||
def validate_guardrail_function(cls, v: Optional[Callable]) -> Optional[Callable]:
|
||||
"""Validate that the guardrail function has the correct signature and behavior.
|
||||
def validate_guardrail_function(
|
||||
cls, v: Optional[str | Callable]
|
||||
) -> Optional[str | Callable]:
|
||||
"""
|
||||
If v is a callable, validate that the guardrail function has the correct signature and behavior.
|
||||
If v is a string, return it as is.
|
||||
|
||||
While type hints provide static checking, this validator ensures runtime safety by:
|
||||
1. Verifying the function accepts exactly one parameter (the TaskOutput)
|
||||
@@ -171,16 +175,16 @@ class Task(BaseModel):
|
||||
- Clear error messages help users debug guardrail implementation issues
|
||||
|
||||
Args:
|
||||
v: The guardrail function to validate
|
||||
v: The guardrail function to validate or a string describing the guardrail task
|
||||
|
||||
Returns:
|
||||
The validated guardrail function
|
||||
The validated guardrail function or a string describing the guardrail task
|
||||
|
||||
Raises:
|
||||
ValueError: If the function signature is invalid or return annotation
|
||||
doesn't match Tuple[bool, Any]
|
||||
"""
|
||||
if v is not None:
|
||||
if v is not None and callable(v):
|
||||
sig = inspect.signature(v)
|
||||
positional_args = [
|
||||
param
|
||||
@@ -211,6 +215,7 @@ class Task(BaseModel):
|
||||
)
|
||||
return v
|
||||
|
||||
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
||||
_original_description: Optional[str] = PrivateAttr(default=None)
|
||||
_original_expected_output: Optional[str] = PrivateAttr(default=None)
|
||||
_original_output_file: Optional[str] = PrivateAttr(default=None)
|
||||
@@ -231,6 +236,20 @@ class Task(BaseModel):
|
||||
)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
def ensure_guardrail_is_callable(self) -> "Task":
    """Resolve ``guardrail`` into the callable stored in ``_guardrail``.

    A callable guardrail is stored as is. A string guardrail is wrapped in
    a ``TaskGuardrail`` built from that string and the LLM of the task's
    agent. Any other value (including ``None``) leaves ``_guardrail``
    untouched.

    Returns:
        The task itself, as required of an "after" model validator.

    Raises:
        ValueError: If the guardrail is a string and the task has no agent,
            since the agent is what supplies the LLM for the guardrail.
    """
    if callable(self.guardrail):
        self._guardrail = self.guardrail
    elif isinstance(self.guardrail, str):
        # Imported lazily: crewai.tasks.task_guardrail itself imports
        # crewai.task, so a module-level import here would be circular.
        from crewai.tasks.task_guardrail import TaskGuardrail

        # Explicit check instead of ``assert``: asserts are stripped under
        # ``python -O``, which would turn this into an opaque
        # AttributeError on ``self.agent.llm``.
        if self.agent is None:
            raise ValueError(
                "An agent is required to build a guardrail from a string description"
            )

        self._guardrail = TaskGuardrail(
            description=self.guardrail, llm=self.agent.llm
        )

    return self
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
|
||||
@@ -407,10 +426,8 @@ class Task(BaseModel):
|
||||
output_format=self._get_output_format(),
|
||||
)
|
||||
|
||||
if self.guardrail:
|
||||
guardrail_result = GuardrailResult.from_tuple(
|
||||
self.guardrail(task_output)
|
||||
)
|
||||
if self._guardrail:
|
||||
guardrail_result = self._process_guardrail(task_output)
|
||||
if not guardrail_result.success:
|
||||
if self.retry_count >= self.max_retries:
|
||||
raise Exception(
|
||||
@@ -464,13 +481,46 @@ class Task(BaseModel):
|
||||
)
|
||||
)
|
||||
self._save_file(content)
|
||||
crewai_event_bus.emit(self, TaskCompletedEvent(output=task_output, task=self))
|
||||
crewai_event_bus.emit(
|
||||
self, TaskCompletedEvent(output=task_output, task=self)
|
||||
)
|
||||
return task_output
|
||||
except Exception as e:
|
||||
self.end_time = datetime.datetime.now()
|
||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
|
||||
raise e # Re-raise the exception after emitting the event
|
||||
|
||||
def _process_guardrail(self, task_output: TaskOutput) -> GuardrailResult:
    """Run the resolved guardrail on ``task_output`` and emit its events.

    Emits ``TaskGuardrailStartedEvent``, invokes ``self._guardrail``,
    converts the returned tuple with ``GuardrailResult.from_tuple`` and
    emits ``TaskGuardrailCompletedEvent`` describing the outcome.

    Args:
        task_output: The task output to validate.

    Returns:
        The ``GuardrailResult`` built from the guardrail's return value.
    """
    # Internal invariant: callers only reach this when a guardrail is set.
    assert self._guardrail is not None

    from crewai.utilities.events import (
        TaskGuardrailCompletedEvent,
        TaskGuardrailStartedEvent,
    )
    from crewai.utilities.events.crewai_event_bus import crewai_event_bus

    # Announce the start BEFORE running the guardrail, so listeners see the
    # event even if the guardrail is slow or raises.
    crewai_event_bus.emit(
        self,
        TaskGuardrailStartedEvent(
            guardrail=self._guardrail, retry_count=self.retry_count
        ),
    )

    result = self._guardrail(task_output)
    guardrail_result = GuardrailResult.from_tuple(result)

    crewai_event_bus.emit(
        self,
        TaskGuardrailCompletedEvent(
            success=guardrail_result.success,
            result=guardrail_result.result,
            error=guardrail_result.error,
            retry_count=self.retry_count,
        ),
    )
    return guardrail_result
|
||||
|
||||
def prompt(self) -> str:
|
||||
"""Prompt the task.
|
||||
|
||||
|
||||
92
src/crewai/tasks/task_guardrail.py
Normal file
92
src/crewai/tasks/task_guardrail.py
Normal file
@@ -0,0 +1,92 @@
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai.agent import Agent, LiteAgentOutput
|
||||
from crewai.llm import LLM
|
||||
from crewai.task import Task
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
|
||||
|
||||
class TaskGuardrailResult(BaseModel):
    """Structured verdict requested from the LLM when validating a task output."""

    # True when the task output satisfies the guardrail.
    valid: bool = Field(
        description="Whether the task output complies with the guardrail",
    )
    # Explanation of the violation; defaults to None.
    feedback: str | None = Field(
        default=None,
        description="A feedback about the task output if it is not valid",
    )
|
||||
|
||||
|
||||
class TaskGuardrail:
    """Validate the output of a task against a textual guardrail using an LLM.

    An instance is callable with a ``TaskOutput`` and returns a
    ``(success, payload)`` tuple. The check is delegated to an ``Agent``
    that is asked to answer with a ``TaskGuardrailResult``.

    Args:
        description (str): The description of the validation criteria.
        llm (LLM): The language model used to validate the task output.
    """

    def __init__(
        self,
        description: str,
        llm: LLM,
    ):
        self.description = description

        self.llm: LLM = llm

    def _validate_output(self, task_output: TaskOutput) -> LiteAgentOutput:
        """Ask a guardrail agent whether ``task_output`` meets the guardrail.

        Args:
            task_output (TaskOutput): The output whose ``raw`` text is checked.

        Returns:
            LiteAgentOutput: The agent result; ``__call__`` expects its
            ``pydantic`` attribute to hold a ``TaskGuardrailResult``.
        """
        agent = Agent(
            role="Guardrail Agent",
            goal="Validate the output of the task",
            backstory="You are a expert at validating the output of a task. By providing effective feedback if the output is not valid.",
            llm=self.llm,
        )

        query = f"""
        Ensure the following task result complies with the given guardrail.

        Task result:
        {task_output.raw}

        Guardrail:
        {self.description}

        Your task:
        - Confirm if the Task result complies with the guardrail.
        - If not, provide clear feedback explaining what is wrong (e.g., by how much it violates the rule, or what specific part fails).
        - Focus only on identifying issues — do not propose corrections.
        - If the Task result complies with the guardrail, saying that is valid
        """

        result = agent.kickoff(query, response_format=TaskGuardrailResult)

        return result

    def __call__(self, task_output: TaskOutput) -> Tuple[bool, Any]:
        """Validates the output of a task based on specified criteria.

        Args:
            task_output (TaskOutput): The output to be validated.

        Returns:
            Tuple[bool, Any]: A tuple containing:
                - bool: True if validation passed, False otherwise
                - Any: The raw task output on success, otherwise the LLM
                  feedback or an error message
        """

        # Deliberately best-effort: any failure while validating is reported
        # as a failed guardrail instead of propagating to the caller.
        try:
            result = self._validate_output(task_output)
            verdict = result.pydantic
            # Explicit check instead of ``assert`` so the reported message is
            # the same when Python runs with ``-O`` (asserts stripped).
            if not isinstance(verdict, TaskGuardrailResult):
                raise TypeError("The guardrail result is not a valid pydantic model")

            if verdict.valid:
                return True, task_output.raw
            return False, verdict.feedback
        except Exception as e:
            return False, f"Error while validating the task output: {str(e)}"
|
||||
@@ -9,6 +9,10 @@ from .crew_events import (
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
)
|
||||
from .task_guardrail_events import (
|
||||
TaskGuardrailCompletedEvent,
|
||||
TaskGuardrailStartedEvent,
|
||||
)
|
||||
from .agent_events import (
|
||||
AgentExecutionStartedEvent,
|
||||
AgentExecutionCompletedEvent,
|
||||
|
||||
@@ -34,6 +34,10 @@ from .task_events import (
|
||||
TaskFailedEvent,
|
||||
TaskStartedEvent,
|
||||
)
|
||||
from .task_guardrail_events import (
|
||||
TaskGuardrailCompletedEvent,
|
||||
TaskGuardrailStartedEvent,
|
||||
)
|
||||
from .tool_usage_events import (
|
||||
ToolUsageErrorEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
@@ -68,4 +72,6 @@ EventTypes = Union[
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMStreamChunkEvent,
|
||||
TaskGuardrailStartedEvent,
|
||||
TaskGuardrailCompletedEvent,
|
||||
]
|
||||
|
||||
38
src/crewai/utilities/events/task_guardrail_events.py
Normal file
38
src/crewai/utilities/events/task_guardrail_events.py
Normal file
@@ -0,0 +1,38 @@
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from crewai.utilities.events.base_events import BaseEvent
|
||||
|
||||
|
||||
class TaskGuardrailStartedEvent(BaseEvent):
    """Event emitted when a guardrail task starts

    Attributes:
        guardrail: The guardrail callable or TaskGuardrail instance. After
            construction it is replaced by a string: the description of a
            TaskGuardrail, or the source code of any other callable (falling
            back to its name/repr when the source cannot be retrieved).
        retry_count: The number of times the guardrail has been retried
    """

    type: str = "task_guardrail_started"
    guardrail: Union[str, Callable]
    retry_count: int

    def __init__(self, **data):
        from inspect import getsource

        # Imported lazily, as in the rest of this change set, so that
        # importing the events package does not load the task module.
        from crewai.tasks.task_guardrail import TaskGuardrail

        super().__init__(**data)

        if isinstance(self.guardrail, TaskGuardrail):
            self.guardrail = self.guardrail.description.strip()
        elif callable(self.guardrail):
            try:
                self.guardrail = getsource(self.guardrail).strip()
            except (OSError, TypeError):
                # getsource fails for builtins, functools.partial objects,
                # callable instances and code without a source file; emitting
                # an event must not break task execution because of that.
                self.guardrail = getattr(
                    self.guardrail, "__qualname__", None
                ) or repr(self.guardrail)
|
||||
|
||||
|
||||
class TaskGuardrailCompletedEvent(BaseEvent):
    """Event emitted when a guardrail task completes

    Attributes:
        success: Whether the guardrail accepted the task output
        result: The result reported by the guardrail
        error: The error reported by the guardrail, if any
        retry_count: The number of times the guardrail has been retried
    """

    type: str = "task_guardrail_completed"
    success: bool
    result: Any
    error: Optional[str] = None
    retry_count: int
|
||||
Reference in New Issue
Block a user