feat: add guardrail support for Agents when using direct kickoff calls

This commit is contained in:
Lucas Gomide
2025-06-10 12:08:17 -03:00
parent 5b740467cb
commit bacc6fd862
10 changed files with 1958 additions and 64 deletions

View File

@@ -1,9 +1,9 @@
import asyncio
import inspect
import uuid
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Type, Union, cast
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator, field_validator
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
@@ -18,6 +18,7 @@ from crewai.llm import LLM
from crewai.tools.base_tool import BaseTool
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities import I18N
from crewai.utilities.guardrail_result import GuardrailResult
from crewai.utilities.agent_utils import (
enforce_rpm_limit,
format_message_for_llm,
@@ -35,7 +36,7 @@ from crewai.utilities.agent_utils import (
render_text_description_and_args,
show_agent_logs,
)
from crewai.utilities.converter import convert_to_model, generate_model_description
from crewai.utilities.converter import generate_model_description
from crewai.utilities.events.agent_events import (
LiteAgentExecutionCompletedEvent,
LiteAgentExecutionErrorEvent,
@@ -146,6 +147,15 @@ class LiteAgent(FlowTrackable, BaseModel):
default=[], description="Callbacks to be used for the agent"
)
# Guardrail Properties
# `guardrail` is either a callable that receives the LiteAgentOutput and
# returns a (success, payload) tuple, or a plain-text description. A string
# is wrapped in an LLMGuardrail by the `ensure_guardrail_is_callable`
# validator; `None` (the default) means no guardrail is configured.
guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = Field(
default=None,
description="Function or string description of a guardrail to validate agent output"
)
# Upper bound on guardrail-triggered retries; `_execute_core` raises once
# `_guardrail_retry_count` has reached this value and validation fails again.
guardrail_max_retries: int = Field(
default=3, description="Maximum number of retries when guardrail fails"
)
# State and Results
tools_results: List[Dict[str, Any]] = Field(
default=[], description="Results of the tools used by the agent."
@@ -163,6 +173,9 @@ class LiteAgent(FlowTrackable, BaseModel):
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
_iterations: int = PrivateAttr(default=0)
_printer: Printer = PrivateAttr(default_factory=Printer)
# Callable form of `guardrail`, filled in by `ensure_guardrail_is_callable`
# (stays None when no guardrail is configured).
_guardrail: Optional[Callable] = PrivateAttr(default=None)
# Number of guardrail-triggered retries performed so far; incremented in
# `_execute_core` and reported in the guardrail events.
_guardrail_retry_count: int = PrivateAttr(default=0)
@model_validator(mode="after")
def setup_llm(self):
@@ -184,6 +197,63 @@ class LiteAgent(FlowTrackable, BaseModel):
return self
@model_validator(mode="after")
def ensure_guardrail_is_callable(self) -> "LiteAgent":
    """Resolve the public ``guardrail`` field into the callable kept in ``_guardrail``.

    A callable guardrail is stored unchanged. A string guardrail is treated
    as a description and wrapped in an ``LLMGuardrail`` built with this
    agent's LLM. When no guardrail is configured, ``_guardrail`` keeps its
    default of ``None``.

    Returns:
        The agent instance itself.

    Raises:
        ValueError: If a string guardrail is configured but ``self.llm`` is
            not an ``LLM`` instance.
    """
    if callable(self.guardrail):
        self._guardrail = self.guardrail
    elif isinstance(self.guardrail, str):
        # Function-scope import kept from the original code (presumably to
        # avoid an import cycle -- TODO confirm).
        from crewai.tasks.llm_guardrail import LLMGuardrail

        # Explicit check instead of `assert`: assert statements are removed
        # under `python -O`, which would let a non-LLM object reach
        # LLMGuardrail unchecked.
        if not isinstance(self.llm, LLM):
            raise ValueError(
                "A string guardrail requires `llm` to be an LLM instance, "
                f"got {type(self.llm).__name__}"
            )

        self._guardrail = LLMGuardrail(description=self.guardrail, llm=self.llm)

    return self
@field_validator("guardrail", mode="before")
@classmethod
def validate_guardrail_function(
    cls, v: Optional[Union[Callable, str]]
) -> Optional[Union[Callable, str]]:
    """Validate the value supplied for the ``guardrail`` field.

    ``None`` and strings are returned untouched. Any other value must be a
    callable that declares exactly one parameter. If the callable carries a
    return annotation, it must be a two-item tuple type whose first item is
    ``bool`` (e.g. ``Tuple[bool, Any]``).

    Args:
        v: The guardrail function to validate, a string describing the
            guardrail task, or ``None``.

    Returns:
        ``v`` unchanged when it passes validation.

    Raises:
        ValueError: If ``v`` is neither a callable nor a string, declares a
            number of parameters other than one, or has a return annotation
            that is not a ``(bool, <anything>)`` tuple type.
    """
    if v is None or isinstance(v, str):
        return v

    if not callable(v):
        raise ValueError("Guardrail must be a callable or a string")

    # Check function signature
    sig = inspect.signature(v)
    if len(sig.parameters) != 1:
        raise ValueError(
            f"Guardrail function must accept exactly 1 parameter (LiteAgentOutput), "
            f"but it accepts {len(sig.parameters)}"
        )

    return_annotation = sig.return_annotation

    # With `from __future__ import annotations` (or a quoted annotation) the
    # annotation is a plain string, which get_origin() cannot inspect.
    # Resolve it to a real type first; if it cannot be resolved, skip the
    # check rather than reject a guardrail that may be correctly annotated.
    if isinstance(return_annotation, str):
        from typing import get_type_hints

        try:
            return_annotation = get_type_hints(v).get("return", sig.empty)
        except Exception:
            return v

    # Check return annotation if present
    if return_annotation is sig.empty:
        return v

    origin = get_origin(return_annotation)
    args = get_args(return_annotation)
    if origin is not tuple or len(args) != 2 or args[0] is not bool:
        raise ValueError(
            "If return type is annotated, it must be Tuple[bool, Any]"
        )

    return v
@property
def key(self) -> str:
"""Get the unique key for this agent instance."""
@@ -223,54 +293,7 @@ class LiteAgent(FlowTrackable, BaseModel):
# Format messages for the LLM
self._messages = self._format_messages(messages)
# Emit event for agent execution start
crewai_event_bus.emit(
self,
event=LiteAgentExecutionStartedEvent(
agent_info=agent_info,
tools=self._parsed_tools,
messages=messages,
),
)
# Execute the agent using invoke loop
agent_finish = self._invoke_loop()
formatted_result: Optional[BaseModel] = None
if self.response_format:
try:
# Cast to BaseModel to ensure type safety
result = self.response_format.model_validate_json(
agent_finish.output
)
if isinstance(result, BaseModel):
formatted_result = result
except Exception as e:
self._printer.print(
content=f"Failed to parse output into response format: {str(e)}",
color="yellow",
)
# Calculate token usage metrics
usage_metrics = self._token_process.get_summary()
# Create output
output = LiteAgentOutput(
raw=agent_finish.output,
pydantic=formatted_result,
agent_role=self.role,
usage_metrics=usage_metrics.model_dump() if usage_metrics else None,
)
# Emit completion event
crewai_event_bus.emit(
self,
event=LiteAgentExecutionCompletedEvent(
agent_info=agent_info,
output=agent_finish.output,
),
)
return output
return self._execute_core(agent_info=agent_info)
except Exception as e:
self._printer.print(
@@ -288,6 +311,90 @@ class LiteAgent(FlowTrackable, BaseModel):
)
raise e
def _execute_core(self, agent_info: Dict[str, Any]) -> LiteAgentOutput:
    """Run the agent and, if a guardrail is configured, retry until it passes.

    Every attempt emits a ``LiteAgentExecutionStartedEvent``, runs
    ``_invoke_loop``, optionally parses the answer into ``response_format``
    and wraps it in a ``LiteAgentOutput``. When the guardrail rejects the
    output, its error is appended to the conversation as a user message and
    a new attempt is made, up to ``guardrail_max_retries`` retries.

    Args:
        agent_info: Agent description forwarded to the emitted events.

    Returns:
        The output of the first attempt accepted by the guardrail (or of the
        only attempt when no guardrail is configured).

    Raises:
        Exception: If the guardrail still fails after
            ``guardrail_max_retries`` retries.
    """
    # Start every execution with a fresh retry budget. Without this reset
    # the counter carries over between runs on the same agent instance, so
    # a later run could fail with fewer (or zero) retries available.
    self._guardrail_retry_count = 0

    # Retries are handled with a loop rather than recursion so the stack
    # depth does not grow with `guardrail_max_retries`.
    while True:
        # Emit event for agent execution start
        crewai_event_bus.emit(
            self,
            event=LiteAgentExecutionStartedEvent(
                agent_info=agent_info,
                tools=self._parsed_tools,
                messages=self._messages,
            ),
        )

        # Execute the agent using invoke loop
        agent_finish = self._invoke_loop()

        formatted_result: Optional[BaseModel] = None
        if self.response_format:
            try:
                # Cast to BaseModel to ensure type safety
                result = self.response_format.model_validate_json(
                    agent_finish.output
                )
                if isinstance(result, BaseModel):
                    formatted_result = result
            except Exception as e:
                # Best effort: a parse failure is reported but does not
                # abort the run; `pydantic` simply stays None.
                self._printer.print(
                    content=f"Failed to parse output into response format: {str(e)}",
                    color="yellow",
                )

        # Calculate token usage metrics
        usage_metrics = self._token_process.get_summary()

        # Create output
        output = LiteAgentOutput(
            raw=agent_finish.output,
            pydantic=formatted_result,
            agent_role=self.role,
            usage_metrics=usage_metrics.model_dump() if usage_metrics else None,
        )

        # No guardrail configured: the first attempt is final.
        if self._guardrail is None:
            break

        guardrail_result = self._process_guardrail(output)

        if guardrail_result.success:
            # Apply guardrail result if available
            if guardrail_result.result is not None:
                if isinstance(guardrail_result.result, str):
                    output.raw = guardrail_result.result
                elif isinstance(guardrail_result.result, BaseModel):
                    output.pydantic = guardrail_result.result
            break

        if self._guardrail_retry_count >= self.guardrail_max_retries:
            raise Exception(
                f"Agent's guardrail failed validation after {self.guardrail_max_retries} retries. "
                f"Last error: {guardrail_result.error}"
            )

        self._guardrail_retry_count += 1
        if self.verbose:
            self._printer.print(
                f"Guardrail failed. Retrying ({self._guardrail_retry_count}/{self.guardrail_max_retries})..."
                f"\n{guardrail_result.error}"
            )

        # Feed the guardrail's complaint back into the conversation so the
        # next attempt can take it into account.
        self._messages.append({
            "role": "user",
            "content": guardrail_result.error or "Guardrail validation failed"
        })

    # Refresh the usage metrics after guardrail processing (presumably so
    # that tokens spent by an LLM-based guardrail are included -- TODO confirm).
    usage_metrics = self._token_process.get_summary()
    output.usage_metrics = usage_metrics.model_dump() if usage_metrics else None

    # Emit completion event
    # NOTE(review): the event carries `agent_finish.output`, not `output.raw`,
    # so a string replacement supplied by the guardrail is not reflected
    # here -- confirm this is intended.
    crewai_event_bus.emit(
        self,
        event=LiteAgentExecutionCompletedEvent(
            agent_info=agent_info,
            output=agent_finish.output,
        ),
    )

    return output
async def kickoff_async(
self, messages: Union[str, List[Dict[str, str]]]
) -> LiteAgentOutput:
@@ -477,3 +584,42 @@ class LiteAgent(FlowTrackable, BaseModel):
def _append_message(self, text: str, role: str = "assistant") -> None:
    """Format ``text`` for the LLM under ``role`` and add it to ``_messages``.

    Args:
        text: The message content.
        role: The role to attach to the message; defaults to ``"assistant"``.
    """
    formatted_message = format_message_for_llm(text, role=role)
    self._messages.append(formatted_message)
def _process_guardrail(self, output: LiteAgentOutput) -> GuardrailResult:
    """Run the configured guardrail on ``output`` and report it through events.

    An ``LLMGuardrailStartedEvent`` is emitted before the guardrail is
    called and an ``LLMGuardrailCompletedEvent`` after its return value has
    been converted with ``GuardrailResult.from_tuple``. Both events carry
    the current retry count.

    Args:
        output: The agent output to validate.

    Returns:
        The guardrail's verdict as a ``GuardrailResult``.
    """
    assert self._guardrail is not None

    # Function-scope imports, as in the original code.
    from crewai.utilities.events import (
        LLMGuardrailCompletedEvent,
        LLMGuardrailStartedEvent,
    )
    from crewai.utilities.events.crewai_event_bus import crewai_event_bus

    started_event = LLMGuardrailStartedEvent(
        guardrail=self._guardrail,
        retry_count=self._guardrail_retry_count,
    )
    crewai_event_bus.emit(self, started_event)

    verdict = GuardrailResult.from_tuple(self._guardrail(output))

    completed_event = LLMGuardrailCompletedEvent(
        success=verdict.success,
        result=verdict.result,
        error=verdict.error,
        retry_count=self._guardrail_retry_count,
    )
    crewai_event_bus.emit(self, completed_event)

    return verdict