diff --git a/src/crewai/lite_agent.py b/src/crewai/lite_agent.py index d1605116c..b4916043c 100644 --- a/src/crewai/lite_agent.py +++ b/src/crewai/lite_agent.py @@ -1,7 +1,7 @@ import asyncio import inspect import uuid -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin +from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union, cast, get_args, get_origin, Self from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator, field_validator @@ -18,7 +18,7 @@ from crewai.llm import LLM from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool from crewai.utilities import I18N -from crewai.utilities.guardrail_result import GuardrailResult +from crewai.utilities.guardrail import process_guardrail from crewai.utilities.agent_utils import ( enforce_rpm_limit, format_message_for_llm, @@ -198,7 +198,7 @@ class LiteAgent(FlowTrackable, BaseModel): return self @model_validator(mode="after") - def ensure_guardrail_is_callable(self) -> "LiteAgent": + def ensure_guardrail_is_callable(self) -> Self: if callable(self.guardrail): self._guardrail = self.guardrail elif isinstance(self.guardrail, str): @@ -228,9 +228,6 @@ class LiteAgent(FlowTrackable, BaseModel): if v is None or isinstance(v, str): return v - if not callable(v): - raise ValueError("Guardrail must be a callable or a string") - # Check function signature sig = inspect.signature(v) if len(sig.parameters) != 1: @@ -352,7 +349,11 @@ class LiteAgent(FlowTrackable, BaseModel): # Process guardrail if set if self._guardrail is not None: - guardrail_result = self._process_guardrail(output) + guardrail_result = process_guardrail( + output=output, + guardrail=self._guardrail, + retry_count=self._guardrail_retry_count + ) if not guardrail_result.success: if self._guardrail_retry_count >= self.guardrail_max_retries: @@ -584,42 +585,3 @@ class LiteAgent(FlowTrackable, BaseModel): def _append_message(self, text: str, role: str = "assistant") -> None: """Append a message to the message list with the given role.""" self._messages.append(format_message_for_llm(text, role=role)) - - def _process_guardrail(self, output: LiteAgentOutput) -> GuardrailResult: - """Process the guardrail for the agent output. - - Args: - output: The output to validate with the guardrail - - Returns: - GuardrailResult: The result of the guardrail validation - """ - assert self._guardrail is not None - - from crewai.utilities.events import ( - LLMGuardrailCompletedEvent, - LLMGuardrailStartedEvent, - ) - from crewai.utilities.events.crewai_event_bus import crewai_event_bus - - crewai_event_bus.emit( - self, - LLMGuardrailStartedEvent( - guardrail=self._guardrail, retry_count=self._guardrail_retry_count - ), - ) - - result = self._guardrail(output) - guardrail_result = GuardrailResult.from_tuple(result) - - crewai_event_bus.emit( - self, - LLMGuardrailCompletedEvent( - success=guardrail_result.success, - result=guardrail_result.result, - error=guardrail_result.error, - retry_count=self._guardrail_retry_count, - ), - ) - - return guardrail_result diff --git a/src/crewai/task.py b/src/crewai/task.py index 3e2278b5e..a320a6896 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -35,12 +35,12 @@ from pydantic_core import PydanticCustomError from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.security import Fingerprint, SecurityConfig -from crewai.utilities.guardrail_result import GuardrailResult from crewai.tasks.output_format import OutputFormat from crewai.tasks.task_output import TaskOutput from crewai.tools.base_tool import BaseTool from crewai.utilities.config import process_config from crewai.utilities.constants import NOT_SPECIFIED +from crewai.utilities.guardrail import process_guardrail, GuardrailResult from crewai.utilities.converter import Converter, convert_to_model from crewai.utilities.events import ( TaskCompletedEvent, @@ -431,7 +431,11 @@ class Task(BaseModel): ) if self._guardrail: - guardrail_result = self._process_guardrail(task_output) + guardrail_result = process_guardrail( + output=task_output, + guardrail=self._guardrail, + retry_count=self.retry_count + ) if not guardrail_result.success: if self.retry_count >= self.max_retries: raise Exception( diff --git a/src/crewai/utilities/guardrail_result.py b/src/crewai/utilities/guardrail.py similarity index 57% rename from src/crewai/utilities/guardrail_result.py rename to src/crewai/utilities/guardrail.py index ba8ebc552..2f159e479 100644 --- a/src/crewai/utilities/guardrail_result.py +++ b/src/crewai/utilities/guardrail.py @@ -1,15 +1,7 @@ -""" -Module for handling task guardrail validation results. - -This module provides the GuardrailResult class which standardizes -the way task guardrails return their validation results. -""" - -from typing import Any, Optional, Tuple, Union +from typing import Any, Callable, Optional, Tuple, Union from pydantic import BaseModel, field_validator - class GuardrailResult(BaseModel): """Result from a task guardrail execution. @@ -54,3 +46,48 @@ class GuardrailResult(BaseModel): result=data if success else None, error=data if not success else None ) + + +def process_guardrail(output: Any, guardrail: Callable, retry_count: int) -> GuardrailResult: + """Process the guardrail for the agent output. + + Args: + output: The output to validate with the guardrail + + Returns: + GuardrailResult: The result of the guardrail validation + """ + from crewai.task import TaskOutput + from crewai.lite_agent import LiteAgentOutput + + assert isinstance(output, TaskOutput) or isinstance(output, LiteAgentOutput), "Output must be a TaskOutput or LiteAgentOutput" + + assert guardrail is not None + + from crewai.utilities.events import ( + LLMGuardrailCompletedEvent, + LLMGuardrailStartedEvent, + ) + from crewai.utilities.events.crewai_event_bus import crewai_event_bus + + crewai_event_bus.emit( + None, + LLMGuardrailStartedEvent( + guardrail=guardrail, retry_count=retry_count + ), + ) + + result = guardrail(output) + guardrail_result = GuardrailResult.from_tuple(result) + + crewai_event_bus.emit( + None, + LLMGuardrailCompletedEvent( + success=guardrail_result.success, + result=guardrail_result.result, + error=guardrail_result.error, + retry_count=retry_count, + ), + ) + + return guardrail_result