Add source to LLM Guardrail events (#3572)

This commit adds a source attribute to LLM Guardrail event emissions so that
subscribers can identify the LiteAgent or Task that executed the guardrail.
Author: Vini Brasil
Date:   2025-09-22 11:58:00 +09:00
Committed by: GitHub
Parent: 9c1096dbdc
Commit: aa8dc9d77f
5 changed files with 111 additions and 152 deletions
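
With this change, event-bus subscribers can attribute guardrail activity to the object that ran it: handlers receive the emitter as their first argument, which was previously None for events emitted from the shared helper. A minimal sketch using the handler API that appears in the tests below:

    from crewai.events.event_bus import crewai_event_bus
    from crewai.events.event_types import (
        LLMGuardrailCompletedEvent,
        LLMGuardrailStartedEvent,
    )

    @crewai_event_bus.on(LLMGuardrailStartedEvent)
    def on_guardrail_started(source, event):
        # `source` is now the Task or LiteAgent that executed the guardrail.
        print(f"guardrail started by {type(source).__name__}, retry {event.retry_count}")

    @crewai_event_bus.on(LLMGuardrailCompletedEvent)
    def on_guardrail_completed(source, event):
        print(f"guardrail completed: success={event.success}, error={event.error}")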

View File

@@ -367,6 +367,7 @@ class LiteAgent(FlowTrackable, BaseModel):
                 output=output,
                 guardrail=self._guardrail,
                 retry_count=self._guardrail_retry_count,
+                event_source=self,
             )
             if not guardrail_result.success:

View File

@@ -5,20 +5,14 @@ import logging
 import threading
 import uuid
 import warnings
+from collections.abc import Callable
 from concurrent.futures import Future
 from copy import copy
 from hashlib import md5
 from pathlib import Path
 from typing import (
     Any,
-    Callable,
     ClassVar,
-    Dict,
-    List,
-    Optional,
-    Set,
-    Tuple,
-    Type,
     Union,
     get_args,
     get_origin,
@@ -35,20 +29,20 @@ from pydantic import (
 from pydantic_core import PydanticCustomError

 from crewai.agents.agent_builder.base_agent import BaseAgent
+from crewai.events.event_bus import crewai_event_bus
+from crewai.events.event_types import (
+    TaskCompletedEvent,
+    TaskFailedEvent,
+    TaskStartedEvent,
+)
 from crewai.security import Fingerprint, SecurityConfig
 from crewai.tasks.output_format import OutputFormat
 from crewai.tasks.task_output import TaskOutput
 from crewai.tools.base_tool import BaseTool
 from crewai.utilities.config import process_config
 from crewai.utilities.constants import NOT_SPECIFIED, _NotSpecified
-from crewai.utilities.guardrail import process_guardrail, GuardrailResult
 from crewai.utilities.converter import Converter, convert_to_model
-from crewai.events.event_types import (
-    TaskCompletedEvent,
-    TaskFailedEvent,
-    TaskStartedEvent,
-)
-from crewai.events.event_bus import crewai_event_bus
+from crewai.utilities.guardrail import process_guardrail
 from crewai.utilities.i18n import I18N
 from crewai.utilities.printer import Printer
 from crewai.utilities.string_utils import interpolate_only
@@ -85,50 +79,50 @@ class Task(BaseModel):
     tools_errors: int = 0
     delegations: int = 0
     i18n: I18N = I18N()
-    name: Optional[str] = Field(default=None)
-    prompt_context: Optional[str] = None
+    name: str | None = Field(default=None)
+    prompt_context: str | None = None
     description: str = Field(description="Description of the actual task.")
     expected_output: str = Field(
         description="Clear definition of expected output for the task."
     )
-    config: Optional[Dict[str, Any]] = Field(
+    config: dict[str, Any] | None = Field(
         description="Configuration for the agent",
         default=None,
     )
-    callback: Optional[Any] = Field(
+    callback: Any | None = Field(
         description="Callback to be executed after the task is completed.", default=None
     )
-    agent: Optional[BaseAgent] = Field(
+    agent: BaseAgent | None = Field(
         description="Agent responsible for execution the task.", default=None
     )
-    context: Union[List["Task"], None, _NotSpecified] = Field(
+    context: list["Task"] | None | _NotSpecified = Field(
         description="Other tasks that will have their output used as context for this task.",
         default=NOT_SPECIFIED,
     )
-    async_execution: Optional[bool] = Field(
+    async_execution: bool | None = Field(
         description="Whether the task should be executed asynchronously or not.",
         default=False,
     )
-    output_json: Optional[Type[BaseModel]] = Field(
+    output_json: type[BaseModel] | None = Field(
         description="A Pydantic model to be used to create a JSON output.",
         default=None,
     )
-    output_pydantic: Optional[Type[BaseModel]] = Field(
+    output_pydantic: type[BaseModel] | None = Field(
         description="A Pydantic model to be used to create a Pydantic output.",
         default=None,
     )
-    output_file: Optional[str] = Field(
+    output_file: str | None = Field(
         description="A file path to be used to create a file output.",
         default=None,
     )
-    create_directory: Optional[bool] = Field(
+    create_directory: bool | None = Field(
         description="Whether to create the directory for output_file if it doesn't exist.",
         default=True,
     )
-    output: Optional[TaskOutput] = Field(
+    output: TaskOutput | None = Field(
         description="Task output, it's final result after being executed", default=None
     )
-    tools: Optional[List[BaseTool]] = Field(
+    tools: list[BaseTool] | None = Field(
         default_factory=list,
         description="Tools the agent is limited to use for this task.",
     )
@@ -141,24 +135,24 @@ class Task(BaseModel):
         frozen=True,
         description="Unique identifier for the object, not set by user.",
     )
-    human_input: Optional[bool] = Field(
+    human_input: bool | None = Field(
         description="Whether the task should have a human review the final answer of the agent",
         default=False,
     )
-    markdown: Optional[bool] = Field(
+    markdown: bool | None = Field(
         description="Whether the task should instruct the agent to return the final answer formatted in Markdown",
         default=False,
     )
-    converter_cls: Optional[Type[Converter]] = Field(
+    converter_cls: type[Converter] | None = Field(
         description="A converter class used to export structured output",
         default=None,
     )
-    processed_by_agents: Set[str] = Field(default_factory=set)
-    guardrail: Optional[Union[Callable[[TaskOutput], Tuple[bool, Any]], str]] = Field(
+    processed_by_agents: set[str] = Field(default_factory=set)
+    guardrail: Callable[[TaskOutput], tuple[bool, Any]] | str | None = Field(
         default=None,
         description="Function or string description of a guardrail to validate task output before proceeding to next task",
     )
-    max_retries: Optional[int] = Field(
+    max_retries: int | None = Field(
         default=None,
         description="[DEPRECATED] Maximum number of retries when guardrail fails. Use guardrail_max_retries instead. Will be removed in v1.0.0",
     )
@@ -166,13 +160,13 @@ class Task(BaseModel):
         default=3, description="Maximum number of retries when guardrail fails"
     )
     retry_count: int = Field(default=0, description="Current number of retries")
-    start_time: Optional[datetime.datetime] = Field(
+    start_time: datetime.datetime | None = Field(
         default=None, description="Start time of the task execution"
     )
-    end_time: Optional[datetime.datetime] = Field(
+    end_time: datetime.datetime | None = Field(
         default=None, description="End time of the task execution"
     )
-    allow_crewai_trigger_context: Optional[bool] = Field(
+    allow_crewai_trigger_context: bool | None = Field(
         default=None,
         description="Whether this task should append 'Trigger Payload: {crewai_trigger_payload}' to the task description when crewai_trigger_payload exists in crew inputs.",
     )
@@ -181,8 +175,8 @@ class Task(BaseModel):
     @field_validator("guardrail")
     @classmethod
     def validate_guardrail_function(
-        cls, v: Optional[str | Callable]
-    ) -> Optional[str | Callable]:
+        cls, v: str | Callable | None
+    ) -> str | Callable | None:
         """
         If v is a callable, validate that the guardrail function has the correct signature and behavior.
         If v is a string, return it as is.
@@ -229,7 +223,7 @@ class Task(BaseModel):
                     return_annotation_args[1] is Any
                     or return_annotation_args[1] is str
                     or return_annotation_args[1] is TaskOutput
-                    or return_annotation_args[1] == Union[str, TaskOutput]
+                    or return_annotation_args[1] == str | TaskOutput
                 )
             ):
                 raise ValueError(
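
For reference, a callable guardrail that satisfies this validator accepts a TaskOutput and returns a (bool, Any) tuple whose second element is annotated as Any, str, TaskOutput, or str | TaskOutput. A minimal sketch (the emptiness check is illustrative, not part of this change):

    from crewai.tasks.task_output import TaskOutput

    def non_empty_guardrail(output: TaskOutput) -> tuple[bool, str]:
        # Return (success, result-or-error-message); a failure triggers a retry.
        if not output.raw.strip():
            return False, "Output was empty; produce a non-empty answer"
        return True, output.raw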
@@ -237,11 +231,11 @@ class Task(BaseModel):
             )
         return v

-    _guardrail: Optional[Callable] = PrivateAttr(default=None)
-    _original_description: Optional[str] = PrivateAttr(default=None)
-    _original_expected_output: Optional[str] = PrivateAttr(default=None)
-    _original_output_file: Optional[str] = PrivateAttr(default=None)
-    _thread: Optional[threading.Thread] = PrivateAttr(default=None)
+    _guardrail: Callable | None = PrivateAttr(default=None)
+    _original_description: str | None = PrivateAttr(default=None)
+    _original_expected_output: str | None = PrivateAttr(default=None)
+    _original_output_file: str | None = PrivateAttr(default=None)
+    _thread: threading.Thread | None = PrivateAttr(default=None)

     @model_validator(mode="before")
     @classmethod
@@ -265,7 +259,9 @@
         elif isinstance(self.guardrail, str):
             from crewai.tasks.llm_guardrail import LLMGuardrail

-            assert self.agent is not None
+            if self.agent is None:
+                raise ValueError("Agent is required to use LLMGuardrail")
+
             self._guardrail = LLMGuardrail(
                 description=self.guardrail, llm=self.agent.llm
             )
@@ -274,7 +270,7 @@
     @field_validator("id", mode="before")
     @classmethod
-    def _deny_user_set_id(cls, v: Optional[UUID4]) -> None:
+    def _deny_user_set_id(cls, v: UUID4 | None) -> None:
         if v:
             raise PydanticCustomError(
                 "may_not_set_field", "This field is not to be set by the user.", {}
             )
@@ -282,7 +278,7 @@
     @field_validator("output_file")
     @classmethod
-    def output_file_validation(cls, value: Optional[str]) -> Optional[str]:
+    def output_file_validation(cls, value: str | None) -> str | None:
         """Validate the output file path.

         Args:
@@ -307,7 +303,7 @@
             )

         # Check for shell expansion first
-        if value.startswith("~") or value.startswith("$"):
+        if value.startswith(("~", "$")):
             raise ValueError(
                 "Shell expansion characters are not allowed in output_file paths"
             )
@@ -373,9 +369,9 @@
     def execute_sync(
         self,
-        agent: Optional[BaseAgent] = None,
-        context: Optional[str] = None,
-        tools: Optional[List[BaseTool]] = None,
+        agent: BaseAgent | None = None,
+        context: str | None = None,
+        tools: list[BaseTool] | None = None,
     ) -> TaskOutput:
         """Execute the task synchronously."""
         return self._execute_core(agent, context, tools)
@@ -397,8 +393,8 @@
     def execute_async(
         self,
         agent: BaseAgent | None = None,
-        context: Optional[str] = None,
-        tools: Optional[List[BaseTool]] = None,
+        context: str | None = None,
+        tools: list[BaseTool] | None = None,
     ) -> Future[TaskOutput]:
         """Execute the task asynchronously."""
         future: Future[TaskOutput] = Future()
@@ -411,9 +407,9 @@
     def _execute_task_async(
         self,
-        agent: Optional[BaseAgent],
-        context: Optional[str],
-        tools: Optional[List[Any]],
+        agent: BaseAgent | None,
+        context: str | None,
+        tools: list[Any] | None,
         future: Future[TaskOutput],
     ) -> None:
         """Execute the task asynchronously with context handling."""
@@ -422,9 +418,9 @@
     def _execute_core(
         self,
-        agent: Optional[BaseAgent],
-        context: Optional[str],
-        tools: Optional[List[Any]],
+        agent: BaseAgent | None,
+        context: str | None,
+        tools: list[Any] | None,
     ) -> TaskOutput:
         """Run the core execution logic of the task."""
         try:
@@ -465,6 +461,7 @@
                 output=task_output,
                 guardrail=self._guardrail,
                 retry_count=self.retry_count,
+                event_source=self,
             )
             if not guardrail_result.success:
                 if self.retry_count >= self.guardrail_max_retries:
@@ -528,41 +525,6 @@
             crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
             raise e  # Re-raise the exception after emitting the event

-    def _process_guardrail(self, task_output: TaskOutput) -> GuardrailResult:
-        assert self._guardrail is not None
-
-        from crewai.events.event_types import (
-            LLMGuardrailCompletedEvent,
-            LLMGuardrailStartedEvent,
-        )
-        from crewai.events.event_bus import crewai_event_bus
-
-        crewai_event_bus.emit(
-            self,
-            LLMGuardrailStartedEvent(
-                guardrail=self._guardrail, retry_count=self.retry_count
-            ),
-        )
-
-        try:
-            result = self._guardrail(task_output)
-            guardrail_result = GuardrailResult.from_tuple(result)
-        except Exception as e:
-            guardrail_result = GuardrailResult(
-                success=False, result=None, error=f"Guardrail execution error: {str(e)}"
-            )
-
-        crewai_event_bus.emit(
-            self,
-            LLMGuardrailCompletedEvent(
-                success=guardrail_result.success,
-                result=guardrail_result.result,
-                error=guardrail_result.error,
-                retry_count=self.retry_count,
-            ),
-        )
-
-        return guardrail_result
-
     def prompt(self) -> str:
         """Generates the task prompt with optional markdown formatting.
@@ -604,7 +566,7 @@ Follow these guidelines:
         return "\n".join(tasks_slices)

     def interpolate_inputs_and_add_conversation_history(
-        self, inputs: Dict[str, Union[str, int, float, Dict[str, Any], List[Any]]]
+        self, inputs: dict[str, str | int | float | dict[str, Any] | list[Any]]
     ) -> None:
         """Interpolate inputs into the task description, expected output, and output file path.
         Add conversation history if present.
@@ -635,14 +597,14 @@ Follow these guidelines:
                     f"Missing required template variable '{e.args[0]}' in description"
                 ) from e
             except ValueError as e:
-                raise ValueError(f"Error interpolating description: {str(e)}") from e
+                raise ValueError(f"Error interpolating description: {e!s}") from e

             try:
                 self.expected_output = interpolate_only(
                     input_string=self._original_expected_output, inputs=inputs
                 )
             except (KeyError, ValueError) as e:
-                raise ValueError(f"Error interpolating expected_output: {str(e)}") from e
+                raise ValueError(f"Error interpolating expected_output: {e!s}") from e

         if self.output_file is not None:
             try:
@@ -650,11 +612,9 @@ Follow these guidelines:
                     input_string=self._original_output_file, inputs=inputs
                 )
             except (KeyError, ValueError) as e:
-                raise ValueError(
-                    f"Error interpolating output_file path: {str(e)}"
-                ) from e
+                raise ValueError(f"Error interpolating output_file path: {e!s}") from e

-        if "crew_chat_messages" in inputs and inputs["crew_chat_messages"]:
+        if inputs.get("crew_chat_messages"):
             conversation_instruction = self.i18n.slice(
                 "conversation_history_instruction"
             )
@@ -681,14 +641,14 @@ Follow these guidelines:
         """Increment the tools errors counter."""
         self.tools_errors += 1

-    def increment_delegations(self, agent_name: Optional[str]) -> None:
+    def increment_delegations(self, agent_name: str | None) -> None:
         """Increment the delegations counter."""
         if agent_name:
             self.processed_by_agents.add(agent_name)
         self.delegations += 1

-    def copy(
-        self, agents: List["BaseAgent"], task_mapping: Dict[str, "Task"]
+    def copy(  # type: ignore
+        self, agents: list["BaseAgent"], task_mapping: dict[str, "Task"]
     ) -> "Task":
         """Creates a deep copy of the Task while preserving its original class type.
@@ -721,20 +681,18 @@ Follow these guidelines:
         cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None
         cloned_tools = copy(self.tools) if self.tools else []

-        copied_task = self.__class__(
+        return self.__class__(
             **copied_data,
             context=cloned_context,
             agent=cloned_agent,
             tools=cloned_tools,
         )

-        return copied_task
-
     def _export_output(
         self, result: str
-    ) -> Tuple[Optional[BaseModel], Optional[Dict[str, Any]]]:
-        pydantic_output: Optional[BaseModel] = None
-        json_output: Optional[Dict[str, Any]] = None
+    ) -> tuple[BaseModel | None, dict[str, Any] | None]:
+        pydantic_output: BaseModel | None = None
+        json_output: dict[str, Any] | None = None

         if self.output_pydantic or self.output_json:
             model_output = convert_to_model(
@@ -764,7 +722,7 @@ Follow these guidelines:
                 return OutputFormat.PYDANTIC
         return OutputFormat.RAW

-    def _save_file(self, result: Union[Dict, str, Any]) -> None:
+    def _save_file(self, result: dict | str | Any) -> None:
         """Save task output to a file.

         Note:
@@ -785,7 +743,7 @@ Follow these guidelines:
         if self.output_file is None:
             raise ValueError("output_file is not set.")

-        FILEWRITER_RECOMMENDATION = (
+        filewriter_recommendation = (
             "For cross-platform file writing, especially on Windows, "
             "use FileWriterTool from crewai_tools package."
         )
@@ -811,10 +769,10 @@ Follow these guidelines:
         except (OSError, IOError) as e:
             raise RuntimeError(
                 "\n".join(
-                    [f"Failed to save output file: {e}", FILEWRITER_RECOMMENDATION]
+                    [f"Failed to save output file: {e}", filewriter_recommendation]
                 )
-            )
-        return None
+            ) from e
+        return

     def __repr__(self):
         return f"Task(description={self.description}, expected_output={self.expected_output})"

View File

@@ -1,4 +1,5 @@
-from typing import Any, Callable, Optional, Tuple, Union
+from collections.abc import Callable
+from typing import Any

 from pydantic import BaseModel, field_validator
@@ -17,8 +18,8 @@ class GuardrailResult(BaseModel):
     """

     success: bool
-    result: Optional[Any] = None
-    error: Optional[str] = None
+    result: Any | None = None
+    error: str | None = None

     @field_validator("result", "error")
     @classmethod
@@ -36,7 +37,7 @@ class GuardrailResult(BaseModel):
         return v

     @classmethod
-    def from_tuple(cls, result: Tuple[bool, Union[Any, str]]) -> "GuardrailResult":
+    def from_tuple(cls, result: tuple[bool, Any | str]) -> "GuardrailResult":
         """Create a GuardrailResult from a validation tuple.

         Args:
@@ -55,33 +56,27 @@
 def process_guardrail(
-    output: Any, guardrail: Callable, retry_count: int
+    output: Any, guardrail: Callable, retry_count: int, event_source: Any | None = None
 ) -> GuardrailResult:
     """Process the guardrail for the agent output.

     Args:
         output: The output to validate with the guardrail
+        guardrail: The guardrail to validate the output with
+        retry_count: The number of times the guardrail has been retried
+        event_source: The source of the guardrail to be sent in events

     Returns:
         GuardrailResult: The result of the guardrail validation
     """
-    from crewai.task import TaskOutput
-    from crewai.lite_agent import LiteAgentOutput
-
-    assert isinstance(output, TaskOutput) or isinstance(
-        output, LiteAgentOutput
-    ), "Output must be a TaskOutput or LiteAgentOutput"
-    assert guardrail is not None
-
+    from crewai.events.event_bus import crewai_event_bus
     from crewai.events.types.llm_guardrail_events import (
         LLMGuardrailCompletedEvent,
         LLMGuardrailStartedEvent,
     )
-    from crewai.events.event_bus import crewai_event_bus

     crewai_event_bus.emit(
-        None,
+        event_source,
         LLMGuardrailStartedEvent(guardrail=guardrail, retry_count=retry_count),
     )
@@ -89,7 +84,7 @@ def process_guardrail(
         guardrail_result = GuardrailResult.from_tuple(result)

     crewai_event_bus.emit(
-        None,
+        event_source,
         LLMGuardrailCompletedEvent(
             success=guardrail_result.success,
             result=guardrail_result.result,
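
Both call sites above now pass themselves as event_source, and process_guardrail forwards that object verbatim as the event emitter. A hypothetical call site mirroring the Task change (task and task_output are stand-ins for real instances):

    from crewai.utilities.guardrail import process_guardrail

    guardrail_result = process_guardrail(
        output=task_output,            # TaskOutput produced by the agent
        guardrail=task._guardrail,     # callable or LLMGuardrail instance
        retry_count=task.retry_count,
        event_source=task,             # surfaced as `source` in emitted events
    )
    if not guardrail_result.success:
        ...  # retry or fail, as the caller decides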

View File

@@ -1,4 +1,3 @@
-# ruff: noqa: S101
 # mypy: ignore-errors
 from collections import defaultdict
 from typing import cast
@@ -329,23 +328,27 @@ def test_guardrail_is_called_using_string():
         LLMGuardrailStartedEvent,
     )

+    agent = Agent(
+        role="Sports Analyst",
+        goal="Gather information about the best soccer players",
+        backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
+        guardrail="""Only include Brazilian players, both women and men""",
+    )
+
     with crewai_event_bus.scoped_handlers():

         @crewai_event_bus.on(LLMGuardrailStartedEvent)
         def capture_guardrail_started(source, event):
+            assert isinstance(source, LiteAgent)
+            assert source.original_agent == agent
             guardrail_events["started"].append(event)

         @crewai_event_bus.on(LLMGuardrailCompletedEvent)
         def capture_guardrail_completed(source, event):
+            assert isinstance(source, LiteAgent)
+            assert source.original_agent == agent
             guardrail_events["completed"].append(event)

-        agent = Agent(
-            role="Sports Analyst",
-            goal="Gather information about the best soccer players",
-            backstory="""You are an expert at gathering and organizing information. You carefully collect details and present them in a structured way.""",
-            guardrail="""Only include Brazilian players, both women and men""",
-        )
-
         result = agent.kickoff(messages="Top 10 best players in the world?")

         assert len(guardrail_events["started"]) == 2
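
Because the emitter is now the originating object rather than None, a single handler can also dispatch on its type; a sketch (the labeling scheme is illustrative):

    from crewai.events.event_bus import crewai_event_bus
    from crewai.events.event_types import LLMGuardrailCompletedEvent
    from crewai.lite_agent import LiteAgent
    from crewai.task import Task

    @crewai_event_bus.on(LLMGuardrailCompletedEvent)
    def attribute_guardrail(source, event):
        # `source` is whatever the caller passed as event_source.
        if isinstance(source, LiteAgent):
            label = f"agent:{source.original_agent.role}"
        elif isinstance(source, Task):
            label = f"task:{source.description[:40]}"
        else:
            label = "unknown"
        print(label, event.success)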

View File

@@ -3,15 +3,15 @@ from unittest.mock import Mock, patch

 import pytest

 from crewai import Agent, Task
-from crewai.llm import LLM
-from crewai.tasks.hallucination_guardrail import HallucinationGuardrail
-from crewai.tasks.llm_guardrail import LLMGuardrail
-from crewai.tasks.task_output import TaskOutput
+from crewai.events.event_bus import crewai_event_bus
 from crewai.events.event_types import (
     LLMGuardrailCompletedEvent,
     LLMGuardrailStartedEvent,
 )
-from crewai.events.event_bus import crewai_event_bus
+from crewai.llm import LLM
+from crewai.tasks.hallucination_guardrail import HallucinationGuardrail
+from crewai.tasks.llm_guardrail import LLMGuardrail
+from crewai.tasks.task_output import TaskOutput


 def test_task_without_guardrail():
@@ -177,16 +177,25 @@ def test_guardrail_emits_events(sample_agent):
     started_guardrail = []
     completed_guardrail = []

+    task = Task(
+        description="Gather information about available books on the First World War",
+        agent=sample_agent,
+        expected_output="A list of available books on the First World War",
+        guardrail="Ensure the authors are from Italy",
+    )
+
     with crewai_event_bus.scoped_handlers():

         @crewai_event_bus.on(LLMGuardrailStartedEvent)
         def handle_guardrail_started(source, event):
+            assert source == task
             started_guardrail.append(
                 {"guardrail": event.guardrail, "retry_count": event.retry_count}
             )

         @crewai_event_bus.on(LLMGuardrailCompletedEvent)
         def handle_guardrail_completed(source, event):
+            assert source == task
             completed_guardrail.append(
                 {
                     "success": event.success,
@@ -196,13 +205,6 @@ def test_guardrail_emits_events(sample_agent):
                 }
             )

-        task = Task(
-            description="Gather information about available books on the First World War",
-            agent=sample_agent,
-            expected_output="A list of available books on the First World War",
-            guardrail="Ensure the authors are from Italy",
-        )
-
         result = task.execute_sync(agent=sample_agent)

     def custom_guardrail(result: TaskOutput):