mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 09:08:31 +00:00
refactor: expose guardrail func in a proper utils file
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin, Self
|
||||
|
||||
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator, field_validator
|
||||
|
||||
@@ -18,7 +18,7 @@ from crewai.llm import LLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.utilities import I18N
|
||||
from crewai.utilities.guardrail_result import GuardrailResult
|
||||
from crewai.utilities.guardrail import process_guardrail
|
||||
from crewai.utilities.agent_utils import (
|
||||
enforce_rpm_limit,
|
||||
format_message_for_llm,
|
||||
@@ -198,7 +198,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_guardrail_is_callable(self) -> "LiteAgent":
|
||||
def ensure_guardrail_is_callable(self) -> Self:
|
||||
if callable(self.guardrail):
|
||||
self._guardrail = self.guardrail
|
||||
elif isinstance(self.guardrail, str):
|
||||
@@ -228,9 +228,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
if v is None or isinstance(v, str):
|
||||
return v
|
||||
|
||||
if not callable(v):
|
||||
raise ValueError("Guardrail must be a callable or a string")
|
||||
|
||||
# Check function signature
|
||||
sig = inspect.signature(v)
|
||||
if len(sig.parameters) != 1:
|
||||
@@ -352,7 +349,11 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
# Process guardrail if set
|
||||
if self._guardrail is not None:
|
||||
guardrail_result = self._process_guardrail(output)
|
||||
guardrail_result = process_guardrail(
|
||||
output=output,
|
||||
guardrail=self._guardrail,
|
||||
retry_count=self._guardrail_retry_count
|
||||
)
|
||||
|
||||
if not guardrail_result.success:
|
||||
if self._guardrail_retry_count >= self.guardrail_max_retries:
|
||||
@@ -584,42 +585,3 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
def _append_message(self, text: str, role: str = "assistant") -> None:
|
||||
"""Append a message to the message list with the given role."""
|
||||
self._messages.append(format_message_for_llm(text, role=role))
|
||||
|
||||
def _process_guardrail(self, output: LiteAgentOutput) -> GuardrailResult:
|
||||
"""Process the guardrail for the agent output.
|
||||
|
||||
Args:
|
||||
output: The output to validate with the guardrail
|
||||
|
||||
Returns:
|
||||
GuardrailResult: The result of the guardrail validation
|
||||
"""
|
||||
assert self._guardrail is not None
|
||||
|
||||
from crewai.utilities.events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
LLMGuardrailStartedEvent(
|
||||
guardrail=self._guardrail, retry_count=self._guardrail_retry_count
|
||||
),
|
||||
)
|
||||
|
||||
result = self._guardrail(output)
|
||||
guardrail_result = GuardrailResult.from_tuple(result)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
LLMGuardrailCompletedEvent(
|
||||
success=guardrail_result.success,
|
||||
result=guardrail_result.result,
|
||||
error=guardrail_result.error,
|
||||
retry_count=self._guardrail_retry_count,
|
||||
),
|
||||
)
|
||||
|
||||
return guardrail_result
|
||||
|
||||
@@ -35,12 +35,12 @@ from pydantic_core import PydanticCustomError
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.security import Fingerprint, SecurityConfig
|
||||
from crewai.utilities.guardrail_result import GuardrailResult
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.config import process_config
|
||||
from crewai.utilities.constants import NOT_SPECIFIED
|
||||
from crewai.utilities.guardrail import process_guardrail, GuardrailResult
|
||||
from crewai.utilities.converter import Converter, convert_to_model
|
||||
from crewai.utilities.events import (
|
||||
TaskCompletedEvent,
|
||||
@@ -431,7 +431,11 @@ class Task(BaseModel):
|
||||
)
|
||||
|
||||
if self._guardrail:
|
||||
guardrail_result = self._process_guardrail(task_output)
|
||||
guardrail_result = process_guardrail(
|
||||
output=task_output,
|
||||
guardrail=self._guardrail,
|
||||
retry_count=self.retry_count
|
||||
)
|
||||
if not guardrail_result.success:
|
||||
if self.retry_count >= self.max_retries:
|
||||
raise Exception(
|
||||
|
||||
@@ -1,15 +1,7 @@
|
||||
"""
|
||||
Module for handling task guardrail validation results.
|
||||
|
||||
This module provides the GuardrailResult class which standardizes
|
||||
the way task guardrails return their validation results.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class GuardrailResult(BaseModel):
|
||||
"""Result from a task guardrail execution.
|
||||
|
||||
@@ -54,3 +46,48 @@ class GuardrailResult(BaseModel):
|
||||
result=data if success else None,
|
||||
error=data if not success else None
|
||||
)
|
||||
|
||||
|
||||
def process_guardrail(output: Any, guardrail: Callable, retry_count: int) -> GuardrailResult:
|
||||
"""Process the guardrail for the agent output.
|
||||
|
||||
Args:
|
||||
output: The output to validate with the guardrail
|
||||
|
||||
Returns:
|
||||
GuardrailResult: The result of the guardrail validation
|
||||
"""
|
||||
from crewai.task import TaskOutput
|
||||
from crewai.lite_agent import LiteAgentOutput
|
||||
|
||||
assert isinstance(output, TaskOutput) or isinstance(output, LiteAgentOutput), "Output must be a TaskOutput or LiteAgentOutput"
|
||||
|
||||
assert guardrail is not None
|
||||
|
||||
from crewai.utilities.events import (
|
||||
LLMGuardrailCompletedEvent,
|
||||
LLMGuardrailStartedEvent,
|
||||
)
|
||||
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
LLMGuardrailStartedEvent(
|
||||
guardrail=guardrail, retry_count=retry_count
|
||||
),
|
||||
)
|
||||
|
||||
result = guardrail(output)
|
||||
guardrail_result = GuardrailResult.from_tuple(result)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
LLMGuardrailCompletedEvent(
|
||||
success=guardrail_result.success,
|
||||
result=guardrail_result.result,
|
||||
error=guardrail_result.error,
|
||||
retry_count=retry_count,
|
||||
),
|
||||
)
|
||||
|
||||
return guardrail_result
|
||||
Reference in New Issue
Block a user