diff --git a/src/crewai/lite_agent.py b/src/crewai/lite_agent.py index d458e6de0..079c4ff3a 100644 --- a/src/crewai/lite_agent.py +++ b/src/crewai/lite_agent.py @@ -154,6 +154,10 @@ class LiteAgent(BaseModel): original_agent: Optional[BaseAgent] = Field( default=None, description="Reference to the agent that created this LiteAgent" ) + guardrail: Optional[str] = Field( + default=None, + description="Description of a guardrail to validate the agent's output", + ) # Private Attributes _parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list) _token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess) @@ -162,7 +166,8 @@ class LiteAgent(BaseModel): _messages: List[Dict[str, str]] = PrivateAttr(default_factory=list) _iterations: int = PrivateAttr(default=0) _printer: Printer = PrivateAttr(default_factory=Printer) - + _guardrail: Optional[Callable] = PrivateAttr(default=None) + @model_validator(mode="after") def setup_llm(self): """Set up the LLM and other components after initialization.""" @@ -183,6 +188,16 @@ class LiteAgent(BaseModel): return self + @model_validator(mode="after") + def init_guardrail(self) -> "LiteAgent": + if isinstance(self.guardrail, str): + from crewai.tasks.llm_guardrail import LLMGuardrail + + assert self.llm is LLM + self._guardrail = LLMGuardrail(description=self.guardrail, llm=self.llm) + + return self + @property def key(self) -> str: """Get the unique key for this agent instance.""" @@ -351,7 +366,7 @@ class LiteAgent(BaseModel): return formatted_messages - def _invoke_loop(self) -> AgentFinish: + def _invoke_loop(self, context: Optional[str] = None) -> AgentFinish: """ Run the agent's thought process until it reaches a conclusion or max iterations. @@ -462,6 +477,14 @@ class LiteAgent(BaseModel): assert isinstance(formatted_answer, AgentFinish) self._show_logs(formatted_answer) + + if self._guardrail: + success, feedback = self._guardrail(formatted_answer.output) + if not success: + return self._invoke_loop(context=feedback) + else: + return formatted_answer + return formatted_answer def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):