LiteAgent w/ Guardrail (#2982)

* feat: add guardrail support for Agents when using direct kickoff calls

* refactor: expose guardrail func in a proper utils file

* fix: resolve Self import on python 3.10
Author: Lucas Gomide
Date: 2025-06-10 14:32:32 -03:00
Committed by: GitHub
parent b0d2e9fe31
commit 739eb72fd0
10 changed files with 1976 additions and 74 deletions
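For orientation, a minimal usage sketch of the feature this commit adds. The role/goal/backstory values and the guardrail body are illustrative, not from the commit; the guardrail and guardrail_max_retries parameters are from the diff below.

from typing import Any, Tuple

from crewai import Agent

def is_short_enough(output) -> Tuple[bool, Any]:
    # Guardrails receive the agent output and return (success, result_or_error).
    if len(output.raw) <= 280:
        return True, output.raw
    return False, "Answer must be 280 characters or fewer."

agent = Agent(
    role="Copywriter",
    goal="Write punchy one-liners",
    backstory="A concise writer.",
    guardrail=is_short_enough,   # a plain string description also works
    guardrail_max_retries=2,     # defaults to 3 per the diff below
)
result = agent.kickoff("Write a tagline for a coffee shop.")
print(result.raw)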

View File

@@ -1,6 +1,6 @@
import shutil
import subprocess
-from typing import Any, Dict, List, Literal, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
@@ -155,6 +155,13 @@ class Agent(BaseAgent):
default=None,
description="The Agent's role to be used from your repository.",
)
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
default=None,
description="Function or string description of a guardrail to validate agent output"
)
guardrail_max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails"
)
@model_validator(mode="before")
def validate_from_repository(cls, v):
@@ -780,6 +787,8 @@ class Agent(BaseAgent):
response_format=response_format,
i18n=self.i18n,
original_agent=self,
guardrail=self.guardrail,
guardrail_max_retries=self.guardrail_max_retries,
)
return lite_agent.kickoff(messages)

View File

@@ -1,9 +1,14 @@
import asyncio
import inspect
import uuid
from datetime import datetime
-from typing import Any, Callable, Dict, List, Optional, Type, Union, cast
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin
-from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
+try:
+    from typing import Self
+except ImportError:
+    from typing_extensions import Self
+from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator, field_validator
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
@@ -18,6 +23,7 @@ from crewai.llm import LLM
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities import I18N
from crewai.utilities.guardrail import process_guardrail
from crewai.utilities.agent_utils import (
enforce_rpm_limit,
format_message_for_llm,
@@ -35,7 +41,7 @@ from crewai.utilities.agent_utils import (
render_text_description_and_args,
show_agent_logs,
)
-from crewai.utilities.converter import convert_to_model, generate_model_description
+from crewai.utilities.converter import generate_model_description
from crewai.utilities.events.agent_events import (
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
@@ -146,6 +152,15 @@ class LiteAgent(FlowTrackable, BaseModel):
default=[], description="Callbacks to be used for the agent"
)
# Guardrail Properties
guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = Field(
default=None,
description="Function or string description of a guardrail to validate agent output"
)
guardrail_max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails"
)
# State and Results
tools_results: List[Dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent."
@@ -163,6 +178,9 @@ class LiteAgent(FlowTrackable, BaseModel):
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
_iterations: int = PrivateAttr(default=0)
_printer: Printer = PrivateAttr(default_factory=Printer)
_guardrail: Optional[Callable] = PrivateAttr(default=None)
_guardrail_retry_count: int = PrivateAttr(default=0)
@model_validator(mode="after")
def setup_llm(self):
@@ -184,6 +202,60 @@ class LiteAgent(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def ensure_guardrail_is_callable(self) -> Self:
if callable(self.guardrail):
self._guardrail = self.guardrail
elif isinstance(self.guardrail, str):
from crewai.tasks.llm_guardrail import LLMGuardrail
assert isinstance(self.llm, LLM)
self._guardrail = LLMGuardrail(
description=self.guardrail, llm=self.llm
)
return self
@field_validator("guardrail", mode="before")
@classmethod
def validate_guardrail_function(cls, v: Optional[Union[Callable, str]]) -> Optional[Union[Callable, str]]:
"""Validate that the guardrail function has the correct signature.
If v is a callable, validate that it has the correct signature.
If v is a string, return it as is.
Args:
v: The guardrail function to validate or a string describing the guardrail task
Returns:
The validated guardrail function or a string describing the guardrail task
"""
if v is None or isinstance(v, str):
return v
# Check function signature
sig = inspect.signature(v)
if len(sig.parameters) != 1:
raise ValueError(
f"Guardrail function must accept exactly 1 parameter (LiteAgentOutput), "
f"but it accepts {len(sig.parameters)}"
)
# Check return annotation if present
if sig.return_annotation is not sig.empty:
if sig.return_annotation == Tuple[bool, Any]:
return v
origin = get_origin(sig.return_annotation)
args = get_args(sig.return_annotation)
if origin is not tuple or len(args) != 2 or args[0] is not bool:
raise ValueError(
"If return type is annotated, it must be Tuple[bool, Any]"
)
return v
@property
def key(self) -> str:
"""Get the unique key for this agent instance."""
@@ -223,54 +295,7 @@ class LiteAgent(FlowTrackable, BaseModel):
# Format messages for the LLM
self._messages = self._format_messages(messages)
-# Emit event for agent execution start
-crewai_event_bus.emit(
-    self,
-    event=LiteAgentExecutionStartedEvent(
-        agent_info=agent_info,
-        tools=self._parsed_tools,
-        messages=messages,
-    ),
-)
-# Execute the agent using invoke loop
-agent_finish = self._invoke_loop()
-formatted_result: Optional[BaseModel] = None
-if self.response_format:
-    try:
-        # Cast to BaseModel to ensure type safety
-        result = self.response_format.model_validate_json(
-            agent_finish.output
-        )
-        if isinstance(result, BaseModel):
-            formatted_result = result
-    except Exception as e:
-        self._printer.print(
-            content=f"Failed to parse output into response format: {str(e)}",
-            color="yellow",
-        )
-# Calculate token usage metrics
-usage_metrics = self._token_process.get_summary()
-# Create output
-output = LiteAgentOutput(
-    raw=agent_finish.output,
-    pydantic=formatted_result,
-    agent_role=self.role,
-    usage_metrics=usage_metrics.model_dump() if usage_metrics else None,
-)
-# Emit completion event
-crewai_event_bus.emit(
-    self,
-    event=LiteAgentExecutionCompletedEvent(
-        agent_info=agent_info,
-        output=agent_finish.output,
-    ),
-)
-return output
+return self._execute_core(agent_info=agent_info)
except Exception as e:
self._printer.print(
@@ -288,6 +313,94 @@ class LiteAgent(FlowTrackable, BaseModel):
)
raise e
def _execute_core(self, agent_info: Dict[str, Any]) -> LiteAgentOutput:
# Emit event for agent execution start
crewai_event_bus.emit(
self,
event=LiteAgentExecutionStartedEvent(
agent_info=agent_info,
tools=self._parsed_tools,
messages=self._messages,
),
)
# Execute the agent using invoke loop
agent_finish = self._invoke_loop()
formatted_result: Optional[BaseModel] = None
if self.response_format:
try:
# Cast to BaseModel to ensure type safety
result = self.response_format.model_validate_json(
agent_finish.output
)
if isinstance(result, BaseModel):
formatted_result = result
except Exception as e:
self._printer.print(
content=f"Failed to parse output into response format: {str(e)}",
color="yellow",
)
# Calculate token usage metrics
usage_metrics = self._token_process.get_summary()
# Create output
output = LiteAgentOutput(
raw=agent_finish.output,
pydantic=formatted_result,
agent_role=self.role,
usage_metrics=usage_metrics.model_dump() if usage_metrics else None,
)
# Process guardrail if set
if self._guardrail is not None:
guardrail_result = process_guardrail(
output=output,
guardrail=self._guardrail,
retry_count=self._guardrail_retry_count
)
if not guardrail_result.success:
if self._guardrail_retry_count >= self.guardrail_max_retries:
raise Exception(
f"Agent's guardrail failed validation after {self.guardrail_max_retries} retries. "
f"Last error: {guardrail_result.error}"
)
self._guardrail_retry_count += 1
if self.verbose:
self._printer.print(
f"Guardrail failed. Retrying ({self._guardrail_retry_count}/{self.guardrail_max_retries})..."
f"\n{guardrail_result.error}"
)
self._messages.append({
"role": "user",
"content": guardrail_result.error or "Guardrail validation failed"
})
return self._execute_core(agent_info=agent_info)
# Apply guardrail result if available
if guardrail_result.result is not None:
if isinstance(guardrail_result.result, str):
output.raw = guardrail_result.result
elif isinstance(guardrail_result.result, BaseModel):
output.pydantic = guardrail_result.result
usage_metrics = self._token_process.get_summary()
output.usage_metrics = usage_metrics.model_dump() if usage_metrics else None
# Emit completion event
crewai_event_bus.emit(
self,
event=LiteAgentExecutionCompletedEvent(
agent_info=agent_info,
output=agent_finish.output,
),
)
return output
async def kickoff_async(
self, messages: Union[str, List[Dict[str, str]]]
) -> LiteAgentOutput:
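To make the validator's contract above concrete, a short sketch of guardrail signatures that pass and fail validate_guardrail_function (function names are illustrative):

from typing import Any, Tuple

def annotated(output) -> Tuple[bool, Any]:
    # Passes: exactly one parameter, return annotated as Tuple[bool, Any].
    return True, output.raw

def unannotated(output):
    # Also passes: the return annotation is optional; only the arity is checked.
    return True, output.raw

def too_many(output, context):
    # Fails: raises ValueError("Guardrail function must accept exactly 1 parameter ...").
    return True, output.raw

def wrong_annotation(output) -> Tuple[str, Any]:
    # Fails: an annotated return type must be Tuple[bool, Any].
    return "yes", output.raw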

View File

@@ -35,12 +35,12 @@ from pydantic_core import PydanticCustomError
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.security import Fingerprint, SecurityConfig
-from crewai.tasks.guardrail_result import GuardrailResult
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
from crewai.tools.base_tool import BaseTool
from crewai.utilities.config import process_config
from crewai.utilities.constants import NOT_SPECIFIED
+from crewai.utilities.guardrail import process_guardrail, GuardrailResult
from crewai.utilities.converter import Converter, convert_to_model
from crewai.utilities.events import (
TaskCompletedEvent,
@@ -431,7 +431,11 @@ class Task(BaseModel):
)
if self._guardrail:
-guardrail_result = self._process_guardrail(task_output)
+guardrail_result = process_guardrail(
+    output=task_output,
+    guardrail=self._guardrail,
+    retry_count=self.retry_count
+)
if not guardrail_result.success:
if self.retry_count >= self.max_retries:
raise Exception(
@@ -527,10 +531,10 @@ class Task(BaseModel):
def prompt(self) -> str:
"""Generates the task prompt with optional markdown formatting.
When the markdown attribute is True, instructions for formatting the
response in Markdown syntax will be added to the prompt.
Returns:
str: The formatted prompt string containing the task description,
expected output, and optional markdown formatting instructions.
@@ -541,7 +545,7 @@ class Task(BaseModel):
expected_output=self.expected_output
)
tasks_slices = [self.description, output]
if self.markdown:
markdown_instruction = """Your final answer MUST be formatted in Markdown syntax.
Follow these guidelines:

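As a side note on the prompt() docstring above, a hedged sketch of the markdown flag in use (the description and expected_output values are illustrative):

from crewai import Task

task = Task(
    description="Summarize the quarterly report.",
    expected_output="A three-bullet summary.",
    markdown=True,  # appends the Markdown formatting guidelines to the prompt
)
print(task.prompt())  # description + expected output + markdown instructions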
View File

@@ -64,7 +64,7 @@ class BaseTool(BaseModel, ABC):
},
},
)
@field_validator("max_usage_count", mode="before")
@classmethod
def validate_max_usage_count(cls, v: int | None) -> int | None:
@@ -88,11 +88,11 @@ class BaseTool(BaseModel, ABC):
# If _run is async, we safely run it
if asyncio.iscoroutine(result):
result = asyncio.run(result)
self.current_usage_count += 1
return result
def reset_usage_count(self) -> None:
"""Reset the current usage count to zero."""
self.current_usage_count = 0
@@ -279,7 +279,7 @@ def to_langchain(
def tool(*args, result_as_answer: bool = False, max_usage_count: int | None = None) -> Callable:
"""
Decorator to create a tool from a function.
Args:
*args: Positional arguments, either the function to decorate or the tool name.
result_as_answer: Flag to indicate if the tool result should be used as the final agent answer.

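Going by the tool() signature above, a sketch of the decorator with a usage cap, assuming the usual crewai.tools import path. The tool body is illustrative; current_usage_count and reset_usage_count are from the diff.

from crewai.tools import tool

@tool("word_counter", max_usage_count=3)
def word_counter(text: str) -> int:
    """Count the words in a piece of text."""
    return len(text.split())

print(word_counter.current_usage_count)  # 0 before any run; increments per run
# After the cap is reached, reset_usage_count() starts the count over:
word_counter.reset_usage_count()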
View File

@@ -1,15 +1,7 @@
"""
Module for handling task guardrail validation results.
This module provides the GuardrailResult class which standardizes
the way task guardrails return their validation results.
"""
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union
from pydantic import BaseModel, field_validator
class GuardrailResult(BaseModel):
"""Result from a task guardrail execution.
@@ -54,3 +46,48 @@ class GuardrailResult(BaseModel):
result=data if success else None,
error=data if not success else None
)
def process_guardrail(output: Any, guardrail: Callable, retry_count: int) -> GuardrailResult:
"""Process the guardrail for the agent output.
Args:
output: The output to validate with the guardrail
Returns:
GuardrailResult: The result of the guardrail validation
"""
from crewai.task import TaskOutput
from crewai.lite_agent import LiteAgentOutput
assert isinstance(output, (TaskOutput, LiteAgentOutput)), "Output must be a TaskOutput or LiteAgentOutput"
assert guardrail is not None
from crewai.utilities.events import (
LLMGuardrailCompletedEvent,
LLMGuardrailStartedEvent,
)
from crewai.utilities.events.crewai_event_bus import crewai_event_bus
crewai_event_bus.emit(
None,
LLMGuardrailStartedEvent(
guardrail=guardrail, retry_count=retry_count
),
)
result = guardrail(output)
guardrail_result = GuardrailResult.from_tuple(result)
crewai_event_bus.emit(
None,
LLMGuardrailCompletedEvent(
success=guardrail_result.success,
result=guardrail_result.result,
error=guardrail_result.error,
retry_count=retry_count,
),
)
return guardrail_result
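Finally, the from_tuple mapping shown earlier in this file, spelled out as a sketch (the import path follows the new import in task.py above):

from crewai.utilities.guardrail import GuardrailResult

ok = GuardrailResult.from_tuple((True, "validated text"))
assert ok.success and ok.result == "validated text" and ok.error is None

bad = GuardrailResult.from_tuple((False, "missing citation"))
assert not bad.success and bad.result is None and bad.error == "missing citation"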