refactor: expose guardrail func in a proper utils file

This commit is contained in:
Lucas Gomide
2025-06-10 14:09:55 -03:00
parent 9532d5d430
commit bff7eb74d4
3 changed files with 60 additions and 57 deletions

View File

@@ -1,7 +1,7 @@
import asyncio
import inspect
import uuid
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin, Self
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator, field_validator
@@ -18,7 +18,7 @@ from crewai.llm import LLM
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities import I18N
from crewai.utilities.guardrail_result import GuardrailResult
from crewai.utilities.guardrail import process_guardrail
from crewai.utilities.agent_utils import (
enforce_rpm_limit,
format_message_for_llm,
@@ -198,7 +198,7 @@ class LiteAgent(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def ensure_guardrail_is_callable(self) -> "LiteAgent":
def ensure_guardrail_is_callable(self) -> Self:
if callable(self.guardrail):
self._guardrail = self.guardrail
elif isinstance(self.guardrail, str):
@@ -228,9 +228,6 @@ class LiteAgent(FlowTrackable, BaseModel):
if v is None or isinstance(v, str):
return v
if not callable(v):
raise ValueError("Guardrail must be a callable or a string")
# Check function signature
sig = inspect.signature(v)
if len(sig.parameters) != 1:
@@ -352,7 +349,11 @@ class LiteAgent(FlowTrackable, BaseModel):
# Process guardrail if set
if self._guardrail is not None:
guardrail_result = self._process_guardrail(output)
guardrail_result = process_guardrail(
output=output,
guardrail=self._guardrail,
retry_count=self._guardrail_retry_count
)
if not guardrail_result.success:
if self._guardrail_retry_count >= self.guardrail_max_retries:
@@ -584,42 +585,3 @@ class LiteAgent(FlowTrackable, BaseModel):
def _append_message(self, text: str, role: str = "assistant") -> None:
"""Append a message to the message list with the given role."""
self._messages.append(format_message_for_llm(text, role=role))
def _process_guardrail(self, output: LiteAgentOutput) -> GuardrailResult:
"""Process the guardrail for the agent output.
Args:
output: The output to validate with the guardrail
Returns:
GuardrailResult: The result of the guardrail validation
"""
assert self._guardrail is not None
from crewai.utilities.events import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
crewai_event_bus.emit(
self,
LLMGuardrailStartedEvent(
guardrail=self._guardrail, retry_count=self._guardrail_retry_count
),
)
result = self._guardrail(output)
guardrail_result = GuardrailResult.from_tuple(result)
crewai_event_bus.emit(
self,
LLMGuardrailCompletedEvent(
success=guardrail_result.success,
result=guardrail_result.result,
error=guardrail_result.error,
retry_count=self._guardrail_retry_count,
),
)
return guardrail_result

View File

@@ -35,12 +35,12 @@ from pydantic_core import PydanticCustomError
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.security import Fingerprint, SecurityConfig
from crewai.utilities.guardrail_result import GuardrailResult
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
from crewai.tools.base_tool import BaseTool
from crewai.utilities.config import process_config
from crewai.utilities.constants import NOT_SPECIFIED
from crewai.utilities.guardrail import process_guardrail, GuardrailResult
from crewai.utilities.converter import Converter, convert_to_model
from crewai.utilities.events import (
TaskCompletedEvent,
@@ -431,7 +431,11 @@ class Task(BaseModel):
)
if self._guardrail:
guardrail_result = self._process_guardrail(task_output)
guardrail_result = process_guardrail(
output=task_output,
guardrail=self._guardrail,
retry_count=self.retry_count
)
if not guardrail_result.success:
if self.retry_count >= self.max_retries:
raise Exception(

View File

@@ -1,15 +1,7 @@
"""
Module for handling task guardrail validation results.
This module provides the GuardrailResult class which standardizes
the way task guardrails return their validation results.
"""
from typing import Any, Optional, Tuple, Union
from typing import Any, Callable, Optional, Tuple, Union
from pydantic import BaseModel, field_validator
class GuardrailResult(BaseModel):
"""Result from a task guardrail execution.
@@ -54,3 +46,48 @@ class GuardrailResult(BaseModel):
result=data if success else None,
error=data if not success else None
)
def process_guardrail(output: Any, guardrail: Callable, retry_count: int) -> GuardrailResult:
"""Process the guardrail for the agent output.
Args:
output: The output to validate with the guardrail
Returns:
GuardrailResult: The result of the guardrail validation
"""
from crewai.task import TaskOutput
from crewai.lite_agent import LiteAgentOutput
assert isinstance(output, TaskOutput) or isinstance(output, LiteAgentOutput), "Output must be a TaskOutput or LiteAgentOutput"
assert guardrail is not None
from crewai.utilities.events import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
crewai_event_bus.emit(
None,
LLMGuardrailStartedEvent(
guardrail=guardrail, retry_count=retry_count
),
)
result = guardrail(output)
guardrail_result = GuardrailResult.from_tuple(result)
crewai_event_bus.emit(
None,
LLMGuardrailCompletedEvent(
success=guardrail_result.success,
result=guardrail_result.result,
error=guardrail_result.error,
retry_count=retry_count,
),
)
return guardrail_result