mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 09:08:31 +00:00
LiteAgent w/ Guardrail (#2982)
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
Some checks failed
Notify Downstream / notify-downstream (push) Has been cancelled
* feat: add guardrail support for Agents when using direct kickoff calls * refactor: expose guardrail func in a proper utils file * fix: resolve Self import on python 3.10
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import shutil
|
||||
import subprocess
|
||||
from typing import Any, Dict, List, Literal, Optional, Sequence, Type, Union
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Tuple, Type, Union
|
||||
|
||||
from pydantic import Field, InstanceOf, PrivateAttr, model_validator
|
||||
|
||||
@@ -155,6 +155,13 @@ class Agent(BaseAgent):
|
||||
default=None,
|
||||
description="The Agent's role to be used from your repository.",
|
||||
)
|
||||
guardrail: Optional[Union[Callable[[Any], Tuple[bool, Any]], str]] = Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output"
|
||||
)
|
||||
guardrail_max_retries: int = Field(
|
||||
default=3, description="Maximum number of retries when guardrail fails"
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
def validate_from_repository(cls, v):
|
||||
@@ -780,6 +787,8 @@ class Agent(BaseAgent):
|
||||
response_format=response_format,
|
||||
i18n=self.i18n,
|
||||
original_agent=self,
|
||||
guardrail=self.guardrail,
|
||||
guardrail_max_retries=self.guardrail_max_retries,
|
||||
)
|
||||
|
||||
return lite_agent.kickoff(messages)
|
||||
|
||||
@@ -1,9 +1,14 @@
|
||||
import asyncio
|
||||
import inspect
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, Union, cast
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
|
||||
try:
|
||||
from typing import Self
|
||||
except ImportError:
|
||||
from typing_extensions import Self
|
||||
|
||||
from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator, field_validator
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
@@ -18,6 +23,7 @@ from crewai.llm import LLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.utilities import I18N
|
||||
from crewai.utilities.guardrail import process_guardrail
|
||||
from crewai.utilities.agent_utils import (
|
||||
enforce_rpm_limit,
|
||||
format_message_for_llm,
|
||||
@@ -35,7 +41,7 @@ from crewai.utilities.agent_utils import (
|
||||
render_text_description_and_args,
|
||||
show_agent_logs,
|
||||
)
|
||||
from crewai.utilities.converter import convert_to_model, generate_model_description
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
from crewai.utilities.events.agent_events import (
|
||||
LiteAgentExecutionCompletedEvent,
|
||||
LiteAgentExecutionErrorEvent,
|
||||
@@ -146,6 +152,15 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
default=[], description="Callbacks to be used for the agent"
|
||||
)
|
||||
|
||||
# Guardrail Properties
|
||||
guardrail: Optional[Union[Callable[[LiteAgentOutput], Tuple[bool, Any]], str]] = Field(
|
||||
default=None,
|
||||
description="Function or string description of a guardrail to validate agent output"
|
||||
)
|
||||
guardrail_max_retries: int = Field(
|
||||
default=3, description="Maximum number of retries when guardrail fails"
|
||||
)
|
||||
|
||||
# State and Results
|
||||
tools_results: List[Dict[str, Any]] = Field(
|
||||
default=[], description="Results of the tools used by the agent."
|
||||
@@ -163,6 +178,9 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
|
||||
_iterations: int = PrivateAttr(default=0)
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
||||
_guardrail_retry_count: int = PrivateAttr(default=0)
|
||||
|
||||
|
||||
@model_validator(mode="after")
|
||||
def setup_llm(self):
|
||||
@@ -184,6 +202,60 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def ensure_guardrail_is_callable(self) -> Self:
    """Resolve the public ``guardrail`` field into the private ``_guardrail`` callable.

    A callable guardrail is stored unchanged. A string guardrail is treated as
    a description and wrapped in an ``LLMGuardrail`` that uses this agent's
    LLM. Any other value (including ``None``) leaves ``_guardrail`` untouched.

    Returns:
        This instance, as expected from an "after" model validator.

    Raises:
        ValueError: If ``guardrail`` is a string but ``self.llm`` is not an
            ``LLM`` instance.
    """
    if callable(self.guardrail):
        self._guardrail = self.guardrail
    elif isinstance(self.guardrail, str):
        # Imported at call time, as in the original code
        # (presumably to avoid a circular import -- TODO confirm).
        from crewai.tasks.llm_guardrail import LLMGuardrail

        # Explicit check rather than `assert`: assertions are removed when
        # Python runs with -O, which would let a non-LLM object through.
        if not isinstance(self.llm, LLM):
            raise ValueError(
                "A string guardrail requires the agent's llm to be an LLM instance, "
                f"got {type(self.llm).__name__}"
            )

        self._guardrail = LLMGuardrail(
            description=self.guardrail, llm=self.llm
        )

    return self
|
||||
|
||||
@field_validator("guardrail", mode="before")
|
||||
@classmethod
|
||||
def validate_guardrail_function(cls, v: Optional[Union[Callable, str]]) -> Optional[Union[Callable, str]]:
    """Validate that a guardrail function has the expected signature.

    ``None`` and string guardrails are returned unchanged. A callable must
    accept exactly one parameter (the ``LiteAgentOutput``). If it declares a
    return annotation, that annotation must be a two-element tuple type whose
    first element is ``bool`` (e.g. ``Tuple[bool, Any]``).

    Args:
        v: The guardrail function to validate, a string describing the
            guardrail task, or ``None``.

    Returns:
        The value that was passed in, once it has been accepted.

    Raises:
        ValueError: If the callable does not take exactly one parameter, or
            its return annotation is not a ``Tuple[bool, ...]`` pair.
    """
    if v is None or isinstance(v, str):
        return v

    # Check function signature
    sig = inspect.signature(v)
    if len(sig.parameters) != 1:
        raise ValueError(
            f"Guardrail function must accept exactly 1 parameter (LiteAgentOutput), "
            f"but it accepts {len(sig.parameters)}"
        )

    # An unannotated return type is accepted without further checks.
    return_annotation = sig.return_annotation
    if return_annotation is sig.empty:
        return v

    if isinstance(return_annotation, str):
        # Postponed annotations (``from __future__ import annotations`` or a
        # quoted annotation) arrive as plain strings, which get_origin()
        # cannot inspect. Resolve them first; if that is impossible the
        # annotation cannot be judged, so the guardrail is accepted.
        from typing import get_type_hints

        try:
            return_annotation = get_type_hints(v).get("return", sig.empty)
        except Exception:
            # get_type_hints may raise NameError, TypeError, AttributeError
            # or SyntaxError depending on the callable and the annotation.
            return v
        if return_annotation is sig.empty:
            return v

    origin = get_origin(return_annotation)
    args = get_args(return_annotation)

    if origin is not tuple or len(args) != 2 or args[0] is not bool:
        raise ValueError(
            "If return type is annotated, it must be Tuple[bool, Any]"
        )

    return v
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Get the unique key for this agent instance."""
|
||||
@@ -223,54 +295,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
# Format messages for the LLM
|
||||
self._messages = self._format_messages(messages)
|
||||
|
||||
# Emit event for agent execution start
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LiteAgentExecutionStartedEvent(
|
||||
agent_info=agent_info,
|
||||
tools=self._parsed_tools,
|
||||
messages=messages,
|
||||
),
|
||||
)
|
||||
|
||||
# Execute the agent using invoke loop
|
||||
agent_finish = self._invoke_loop()
|
||||
formatted_result: Optional[BaseModel] = None
|
||||
if self.response_format:
|
||||
try:
|
||||
# Cast to BaseModel to ensure type safety
|
||||
result = self.response_format.model_validate_json(
|
||||
agent_finish.output
|
||||
)
|
||||
if isinstance(result, BaseModel):
|
||||
formatted_result = result
|
||||
except Exception as e:
|
||||
self._printer.print(
|
||||
content=f"Failed to parse output into response format: {str(e)}",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
# Calculate token usage metrics
|
||||
usage_metrics = self._token_process.get_summary()
|
||||
|
||||
# Create output
|
||||
output = LiteAgentOutput(
|
||||
raw=agent_finish.output,
|
||||
pydantic=formatted_result,
|
||||
agent_role=self.role,
|
||||
usage_metrics=usage_metrics.model_dump() if usage_metrics else None,
|
||||
)
|
||||
|
||||
# Emit completion event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LiteAgentExecutionCompletedEvent(
|
||||
agent_info=agent_info,
|
||||
output=agent_finish.output,
|
||||
),
|
||||
)
|
||||
|
||||
return output
|
||||
return self._execute_core(agent_info=agent_info)
|
||||
|
||||
except Exception as e:
|
||||
self._printer.print(
|
||||
@@ -288,6 +313,94 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
)
|
||||
raise e
|
||||
|
||||
def _execute_core(self, agent_info: Dict[str, Any]) -> LiteAgentOutput:
    """Run one agent attempt and, if a guardrail is set, validate its output.

    Emits a started event, runs the invoke loop, optionally parses the raw
    output into ``response_format``, and builds a ``LiteAgentOutput``. When a
    guardrail is configured and rejects the output, the guardrail error is
    appended to the conversation as a user message and this method calls
    itself again, up to ``guardrail_max_retries`` times.

    Args:
        agent_info: Agent metadata attached to the started/completed events.

    Returns:
        The agent output, with ``raw`` or ``pydantic`` replaced by the
        guardrail's result when the guardrail returned a string or a
        ``BaseModel``.

    Raises:
        Exception: If the guardrail still fails once the retry budget is
            exhausted.
    """
    # Emit event for agent execution start (emitted again on every retry,
    # because retries re-enter this method).
    crewai_event_bus.emit(
        self,
        event=LiteAgentExecutionStartedEvent(
            agent_info=agent_info,
            tools=self._parsed_tools,
            messages=self._messages,
        ),
    )

    # Execute the agent using invoke loop
    agent_finish = self._invoke_loop()
    formatted_result: Optional[BaseModel] = None
    if self.response_format:
        try:
            result = self.response_format.model_validate_json(
                agent_finish.output
            )
            if isinstance(result, BaseModel):
                formatted_result = result
        except Exception as e:
            # Best effort: a parse failure is reported but does not abort
            # the run; ``pydantic`` simply stays None.
            self._printer.print(
                content=f"Failed to parse output into response format: {str(e)}",
                color="yellow",
            )

    # Calculate token usage metrics
    usage_metrics = self._token_process.get_summary()

    # Create output
    output = LiteAgentOutput(
        raw=agent_finish.output,
        pydantic=formatted_result,
        agent_role=self.role,
        usage_metrics=usage_metrics.model_dump() if usage_metrics else None,
    )

    # Process guardrail if set
    if self._guardrail is not None:
        guardrail_result = process_guardrail(
            output=output,
            guardrail=self._guardrail,
            retry_count=self._guardrail_retry_count
        )

        if not guardrail_result.success:
            if self._guardrail_retry_count >= self.guardrail_max_retries:
                # Reset so that a later run on this same instance starts
                # with a full retry budget instead of failing immediately.
                self._guardrail_retry_count = 0
                raise Exception(
                    f"Agent's guardrail failed validation after {self.guardrail_max_retries} retries. "
                    f"Last error: {guardrail_result.error}"
                )

            self._guardrail_retry_count += 1
            if self.verbose:
                self._printer.print(
                    f"Guardrail failed. Retrying ({self._guardrail_retry_count}/{self.guardrail_max_retries})..."
                    f"\n{guardrail_result.error}"
                )

            # Feed the guardrail error back so the next attempt can correct it.
            self._messages.append({
                "role": "user",
                "content": guardrail_result.error or "Guardrail validation failed"
            })

            return self._execute_core(agent_info=agent_info)

        # Guardrail passed: retries consumed by this run must not reduce the
        # budget of the next one.
        self._guardrail_retry_count = 0

        # Apply guardrail result if available
        if guardrail_result.result is not None:
            if isinstance(guardrail_result.result, str):
                output.raw = guardrail_result.result
            elif isinstance(guardrail_result.result, BaseModel):
                output.pydantic = guardrail_result.result

    # Refresh the usage summary after guardrail processing
    # (presumably because an LLM-based guardrail consumes tokens -- TODO confirm).
    usage_metrics = self._token_process.get_summary()
    output.usage_metrics = usage_metrics.model_dump() if usage_metrics else None

    # Emit completion event
    crewai_event_bus.emit(
        self,
        event=LiteAgentExecutionCompletedEvent(
            agent_info=agent_info,
            output=agent_finish.output,
        ),
    )

    return output
|
||||
|
||||
async def kickoff_async(
|
||||
self, messages: Union[str, List[Dict[str, str]]]
|
||||
) -> LiteAgentOutput:
|
||||
|
||||
@@ -35,12 +35,12 @@ from pydantic_core import PydanticCustomError
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.security import Fingerprint, SecurityConfig
|
||||
from crewai.tasks.guardrail_result import GuardrailResult
|
||||
from crewai.tasks.output_format import OutputFormat
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.config import process_config
|
||||
from crewai.utilities.constants import NOT_SPECIFIED
|
||||
from crewai.utilities.guardrail import process_guardrail, GuardrailResult
|
||||
from crewai.utilities.converter import Converter, convert_to_model
|
||||
from crewai.utilities.events import (
|
||||
TaskCompletedEvent,
|
||||
@@ -431,7 +431,11 @@ class Task(BaseModel):
|
||||
)
|
||||
|
||||
if self._guardrail:
|
||||
guardrail_result = self._process_guardrail(task_output)
|
||||
guardrail_result = process_guardrail(
|
||||
output=task_output,
|
||||
guardrail=self._guardrail,
|
||||
retry_count=self.retry_count
|
||||
)
|
||||
if not guardrail_result.success:
|
||||
if self.retry_count >= self.max_retries:
|
||||
raise Exception(
|
||||
@@ -527,10 +531,10 @@ class Task(BaseModel):
|
||||
|
||||
def prompt(self) -> str:
|
||||
"""Generates the task prompt with optional markdown formatting.
|
||||
|
||||
|
||||
When the markdown attribute is True, instructions for formatting the
|
||||
response in Markdown syntax will be added to the prompt.
|
||||
|
||||
|
||||
Returns:
|
||||
str: The formatted prompt string containing the task description,
|
||||
expected output, and optional markdown formatting instructions.
|
||||
@@ -541,7 +545,7 @@ class Task(BaseModel):
|
||||
expected_output=self.expected_output
|
||||
)
|
||||
tasks_slices = [self.description, output]
|
||||
|
||||
|
||||
if self.markdown:
|
||||
markdown_instruction = """Your final answer MUST be formatted in Markdown syntax.
|
||||
Follow these guidelines:
|
||||
|
||||
@@ -64,7 +64,7 @@ class BaseTool(BaseModel, ABC):
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@field_validator("max_usage_count", mode="before")
|
||||
@classmethod
|
||||
def validate_max_usage_count(cls, v: int | None) -> int | None:
|
||||
@@ -88,11 +88,11 @@ class BaseTool(BaseModel, ABC):
|
||||
# If _run is async, we safely run it
|
||||
if asyncio.iscoroutine(result):
|
||||
result = asyncio.run(result)
|
||||
|
||||
|
||||
self.current_usage_count += 1
|
||||
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def reset_usage_count(self) -> None:
|
||||
"""Reset the current usage count to zero."""
|
||||
self.current_usage_count = 0
|
||||
@@ -279,7 +279,7 @@ def to_langchain(
|
||||
def tool(*args, result_as_answer: bool = False, max_usage_count: int | None = None) -> Callable:
|
||||
"""
|
||||
Decorator to create a tool from a function.
|
||||
|
||||
|
||||
Args:
|
||||
*args: Positional arguments, either the function to decorate or the tool name.
|
||||
result_as_answer: Flag to indicate if the tool result should be used as the final agent answer.
|
||||
|
||||
@@ -1,15 +1,7 @@
|
||||
"""
|
||||
Module for handling task guardrail validation results.
|
||||
|
||||
This module provides the GuardrailResult class which standardizes
|
||||
the way task guardrails return their validation results.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional, Tuple, Union
|
||||
from typing import Any, Callable, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
|
||||
|
||||
class GuardrailResult(BaseModel):
|
||||
"""Result from a task guardrail execution.
|
||||
|
||||
@@ -54,3 +46,48 @@ class GuardrailResult(BaseModel):
|
||||
result=data if success else None,
|
||||
error=data if not success else None
|
||||
)
|
||||
|
||||
|
||||
def process_guardrail(output: Any, guardrail: Callable, retry_count: int) -> GuardrailResult:
    """Run a guardrail against a task or agent output and report the outcome.

    An ``LLMGuardrailStartedEvent`` is emitted before the guardrail is called
    and an ``LLMGuardrailCompletedEvent`` afterwards.

    Args:
        output: The output to validate; must be a ``TaskOutput`` or a
            ``LiteAgentOutput``.
        guardrail: Callable invoked with ``output``. Its return value is
            converted with ``GuardrailResult.from_tuple``.
        retry_count: Number of retries performed so far; included in both
            emitted events.

    Returns:
        GuardrailResult: The result of the guardrail validation.

    Raises:
        TypeError: If ``output`` is neither a ``TaskOutput`` nor a
            ``LiteAgentOutput``.
        ValueError: If ``guardrail`` is ``None``.
    """
    # Imported at call time, as in the original code
    # (presumably to avoid circular imports -- TODO confirm).
    from crewai.task import TaskOutput
    from crewai.lite_agent import LiteAgentOutput
    from crewai.utilities.events import (
        LLMGuardrailCompletedEvent,
        LLMGuardrailStartedEvent,
    )
    from crewai.utilities.events.crewai_event_bus import crewai_event_bus

    # Explicit raises rather than `assert`: assertions are removed when
    # Python runs with -O, which would silently disable this validation.
    if not isinstance(output, (TaskOutput, LiteAgentOutput)):
        raise TypeError("Output must be a TaskOutput or LiteAgentOutput")
    if guardrail is None:
        raise ValueError("A guardrail callable is required")

    crewai_event_bus.emit(
        None,
        LLMGuardrailStartedEvent(
            guardrail=guardrail, retry_count=retry_count
        ),
    )

    result = guardrail(output)
    guardrail_result = GuardrailResult.from_tuple(result)

    crewai_event_bus.emit(
        None,
        LLMGuardrailCompletedEvent(
            success=guardrail_result.success,
            result=guardrail_result.result,
            error=guardrail_result.error,
            retry_count=retry_count,
        ),
    )

    return guardrail_result
|
||||
Reference in New Issue
Block a user