mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 00:02:36 +00:00
feat: add guardrail support for Agents when using direct kickoff calls
This commit is contained in:
@@ -1,9 +1,9 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union, cast
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
|
||||
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator, field_validator
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
@@ -18,6 +18,7 @@ from crewai.llm import LLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.utilities import I18N
|
||||
from crewai.utilities.guardrail_result import GuardrailResult
|
||||
from crewai.utilities.agent_utils import (
|
||||
enforce_rpm_limit,
|
||||
format_message_for_llm,
|
||||
@@ -35,7 +36,7 @@ from crewai.utilities.agent_utils import (
|
||||
render_text_description_and_args,
|
||||
show_agent_logs,
|
||||
)
|
||||
from crewai.utilities.converter import convert_to_model, generate_model_description
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.utilities.events.agent_events import (
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
@@ -146,6 +147,15 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
default=[], description="Callbacks to be used for the agent"
|
||||
)
|
||||
|
||||
# Guardrail Properties
|
||||
guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output"
|
||||
)
|
||||
guardrail_max_retries: int = Field(
|
||||
default=3, description="Maximum number of retries when guardrail fails"
|
||||
)
|
||||
|
||||
# State and Results
|
||||
tools_results: List[Dict[str, Any]] = Field(
|
||||
default=[], description="Results of the tools used by the agent."
|
||||
@@ -163,6 +173,9 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
|
||||
_iterations: int = PrivateAttr(default=0)
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
||||
_guardrail_retry_count: int = PrivateAttr(default=0)
|
||||
|
||||
|
||||
@model_validator(mode="after")
|
||||
def setup_llm(self):
|
||||
@@ -184,6 +197,63 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
def ensure_guardrail_is_callable(self) -> "LiteAgent":
    """Resolve the public ``guardrail`` field into the private ``_guardrail`` callable.

    A callable guardrail is stored unchanged. A string guardrail is wrapped
    in an ``LLMGuardrail`` built from that description and this agent's
    ``llm``. When ``guardrail`` is neither, ``_guardrail`` is left as it was.

    Returns:
        This agent instance.

    Raises:
        ValueError: If ``guardrail`` is a string but ``self.llm`` is not an
            ``LLM`` instance.
    """
    if callable(self.guardrail):
        self._guardrail = self.guardrail
    elif isinstance(self.guardrail, str):
        # Function-scope import kept from the original code.
        # NOTE(review): presumably this avoids a circular import - confirm.
        from crewai.tasks.llm_guardrail import LLMGuardrail

        # Explicit check rather than `assert`: assertions are removed when
        # Python runs with -O, which would let a non-LLM object slip through
        # into LLMGuardrail.
        if not isinstance(self.llm, LLM):
            raise ValueError(
                "A string guardrail requires `llm` to be an LLM instance, "
                f"got {type(self.llm).__name__}"
            )

        self._guardrail = LLMGuardrail(
            description=self.guardrail, llm=self.llm
        )

    return self
|
||||
|
||||
@field_validator("guardrail", mode="before")
@classmethod
def validate_guardrail_function(cls, v: Optional[Union[Callable, str]]) -> Optional[Union[Callable, str]]:
    """Validate the shape of a user-supplied guardrail.

    ``None`` and strings are returned unchanged. Any other value must be a
    callable whose signature has exactly one parameter. If the callable
    carries a resolved return annotation, it must be a two-item tuple type
    whose first item is ``bool`` (e.g. ``Tuple[bool, Any]``).

    Args:
        v: The guardrail function to validate, or a string describing the
            guardrail task.

    Returns:
        The value that was passed in, unchanged.

    Raises:
        ValueError: If ``v`` is not ``None``, a string, or a callable; if the
            callable does not take exactly one parameter; or if its return
            annotation is present, resolved, and not a ``(bool, X)`` tuple.
    """
    if v is None or isinstance(v, str):
        return v

    if not callable(v):
        raise ValueError("Guardrail must be a callable or a string")

    # Check function signature
    sig = inspect.signature(v)
    if len(sig.parameters) != 1:
        raise ValueError(
            f"Guardrail function must accept exactly 1 parameter (LiteAgentOutput), "
            f"but it accepts {len(sig.parameters)}"
        )

    return_annotation = sig.return_annotation

    # No annotation: nothing further to check.
    if return_annotation is sig.empty:
        return v

    # A string annotation (quoted, or produced by
    # `from __future__ import annotations`) has no origin/args to inspect, so
    # get_origin() would return None and a correctly annotated guardrail
    # would be rejected. Treat it as unverifiable and accept it.
    if isinstance(return_annotation, str):
        return v

    if return_annotation == Tuple[bool, Any]:
        return v

    origin = get_origin(return_annotation)
    args = get_args(return_annotation)

    if origin is not tuple or len(args) != 2 or args[0] is not bool:
        raise ValueError(
            "If return type is annotated, it must be Tuple[bool, Any]"
        )

    return v
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Get the unique key for this agent instance."""
|
||||
@@ -223,54 +293,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
# Format messages for the LLM
|
||||
self._messages = self._format_messages(messages)
|
||||
|
||||
# Emit event for agent execution start
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LiteAgentExecutionStartedEvent(
|
||||
agent_info=agent_info,
|
||||
tools=self._parsed_tools,
|
||||
messages=messages,
|
||||
),
|
||||
)
|
||||
|
||||
# Execute the agent using invoke loop
|
||||
agent_finish = self._invoke_loop()
|
||||
formatted_result: Optional[BaseModel] = None
|
||||
if self.response_format:
|
||||
try:
|
||||
# Cast to BaseModel to ensure type safety
|
||||
result = self.response_format.model_validate_json(
|
||||
agent_finish.output
|
||||
)
|
||||
if isinstance(result, BaseModel):
|
||||
formatted_result = result
|
||||
except Exception as e:
|
||||
self._printer.print(
|
||||
content=f"Failed to parse output into response format: {str(e)}",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
# Calculate token usage metrics
|
||||
usage_metrics = self._token_process.get_summary()
|
||||
|
||||
# Create output
|
||||
output = LiteAgentOutput(
|
||||
raw=agent_finish.output,
|
||||
pydantic=formatted_result,
|
||||
agent_role=self.role,
|
||||
usage_metrics=usage_metrics.model_dump() if usage_metrics else None,
|
||||
)
|
||||
|
||||
# Emit completion event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LiteAgentExecutionCompletedEvent(
|
||||
agent_info=agent_info,
|
||||
output=agent_finish.output,
|
||||
),
|
||||
)
|
||||
|
||||
return output
|
||||
return self._execute_core(agent_info=agent_info)
|
||||
|
||||
except Exception as e:
|
||||
self._printer.print(
|
||||
@@ -288,6 +311,90 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
)
|
||||
raise e
|
||||
|
||||
def _execute_core(self, agent_info: Dict[str, Any]) -> LiteAgentOutput:
    """Run one execution pass and apply the guardrail, retrying on failure.

    The method emits a started event, runs the invoke loop, optionally
    parses the raw output into ``response_format``, and builds a
    ``LiteAgentOutput``. If a guardrail is configured the output is
    validated; on failure the guardrail error is appended to the
    conversation as a user message and the method calls itself again.

    Args:
        agent_info: Agent metadata forwarded unchanged into the emitted
            started and completed events.

    Returns:
        The agent output, with ``raw`` or ``pydantic`` replaced by the
        guardrail's result when the guardrail returned a string or a
        ``BaseModel`` respectively.

    Raises:
        Exception: If the guardrail fails and ``_guardrail_retry_count`` has
            already reached ``guardrail_max_retries``.
    """
    # Emit event for agent execution start (this also fires on every
    # guardrail retry, because retries re-enter this method).
    crewai_event_bus.emit(
        self,
        event=LiteAgentExecutionStartedEvent(
            agent_info=agent_info,
            tools=self._parsed_tools,
            messages=self._messages,
        ),
    )

    # Execute the agent using invoke loop
    agent_finish = self._invoke_loop()

    formatted_result: Optional[BaseModel] = None
    if self.response_format:
        try:
            result = self.response_format.model_validate_json(
                agent_finish.output
            )
            if isinstance(result, BaseModel):
                formatted_result = result
        except Exception as e:
            # Deliberate best-effort: a parse failure is reported but the
            # raw output is still returned with ``pydantic`` left as None.
            self._printer.print(
                content=f"Failed to parse output into response format: {str(e)}",
                color="yellow",
            )

    # Calculate token usage metrics
    usage_metrics = self._token_process.get_summary()

    # Create output
    output = LiteAgentOutput(
        raw=agent_finish.output,
        pydantic=formatted_result,
        agent_role=self.role,
        usage_metrics=usage_metrics.model_dump() if usage_metrics else None,
    )

    # Process guardrail if set
    if self._guardrail is not None:
        guardrail_result = self._process_guardrail(output)

        if not guardrail_result.success:
            # NOTE(review): the retry counter is never reset in this method;
            # confirm the caller resets it between separate kickoffs.
            if self._guardrail_retry_count >= self.guardrail_max_retries:
                raise Exception(
                    f"Agent's guardrail failed validation after {self.guardrail_max_retries} retries. "
                    f"Last error: {guardrail_result.error}"
                )

            self._guardrail_retry_count += 1
            if self.verbose:
                self._printer.print(
                    f"Guardrail failed. Retrying ({self._guardrail_retry_count}/{self.guardrail_max_retries})..."
                    f"\n{guardrail_result.error}"
                )

            # Feed the guardrail's complaint back into the conversation so
            # the next pass can correct the answer.
            self._messages.append({
                "role": "user",
                "content": guardrail_result.error or "Guardrail validation failed"
            })

            return self._execute_core(agent_info=agent_info)

        # Apply guardrail result if available
        if guardrail_result.result is not None:
            if isinstance(guardrail_result.result, str):
                output.raw = guardrail_result.result
            elif isinstance(guardrail_result.result, BaseModel):
                output.pydantic = guardrail_result.result

    # Re-read the token summary after the guardrail step.
    # NOTE(review): presumably because an LLM-backed guardrail consumes
    # tokens too - confirm.
    usage_metrics = self._token_process.get_summary()
    output.usage_metrics = usage_metrics.model_dump() if usage_metrics else None

    # Emit completion event with the text that is actually returned to the
    # caller. The guardrail may have replaced ``output.raw``; reporting
    # ``agent_finish.output`` here would give listeners a stale value.
    crewai_event_bus.emit(
        self,
        event=LiteAgentExecutionCompletedEvent(
            agent_info=agent_info,
            output=output.raw,
        ),
    )

    return output
|
||||
|
||||
async def kickoff_async(
|
||||
self, messages: Union[str, List[Dict[str, str]]]
|
||||
) -> LiteAgentOutput:
|
||||
@@ -477,3 +584,42 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
def _append_message(self, text: str, role: str = "assistant") -> None:
    """Format ``text`` for the LLM under ``role`` and add it to the message list."""
    formatted = format_message_for_llm(text, role=role)
    self._messages.append(formatted)
|
||||
|
||||
def _process_guardrail(self, output: LiteAgentOutput) -> GuardrailResult:
    """Run the configured guardrail against ``output`` and report it on the event bus.

    A started event is emitted before the guardrail is called and a
    completed event afterwards; both carry the current retry count.

    Args:
        output: The agent output to hand to the guardrail callable.

    Returns:
        GuardrailResult: The guardrail's return value converted with
        ``GuardrailResult.from_tuple``.
    """
    assert self._guardrail is not None

    from crewai.utilities.events import (
        LLMGuardrailCompletedEvent,
        LLMGuardrailStartedEvent,
    )
    from crewai.utilities.events.crewai_event_bus import crewai_event_bus

    started_event = LLMGuardrailStartedEvent(
        guardrail=self._guardrail, retry_count=self._guardrail_retry_count
    )
    crewai_event_bus.emit(self, started_event)

    # The guardrail's raw return value is normalised into a GuardrailResult.
    outcome = GuardrailResult.from_tuple(self._guardrail(output))

    completed_event = LLMGuardrailCompletedEvent(
        success=outcome.success,
        result=outcome.result,
        error=outcome.error,
        retry_count=self._guardrail_retry_count,
    )
    crewai_event_bus.emit(self, completed_event)

    return outcome
|
||||
|
||||
Reference in New Issue
Block a user