mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
kickoff to support guardrail for an LiteAgent
This commit is contained in:
@@ -154,6 +154,10 @@ class LiteAgent(BaseModel):
|
||||
original_agent: Optional[BaseAgent] = Field(
|
||||
default=None, description="Reference to the agent that created this LiteAgent"
|
||||
)
|
||||
guardrail: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Description of a guardrail to validate the agent's output",
|
||||
)
|
||||
# Private Attributes
|
||||
_parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
||||
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
||||
@@ -162,7 +166,8 @@ class LiteAgent(BaseModel):
|
||||
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
|
||||
_iterations: int = PrivateAttr(default=0)
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
|
||||
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def setup_llm(self):
|
||||
"""Set up the LLM and other components after initialization."""
|
||||
@@ -183,6 +188,16 @@ class LiteAgent(BaseModel):
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def init_guardrail(self) -> "LiteAgent":
|
||||
if isinstance(self.guardrail, str):
|
||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
|
||||
assert self.llm is LLM
|
||||
self._guardrail = LLMGuardrail(description=self.guardrail, llm=self.llm)
|
||||
|
||||
return self
|
||||
|
||||
@property
|
||||
def key(self) -> str:
|
||||
"""Get the unique key for this agent instance."""
|
||||
@@ -351,7 +366,7 @@ class LiteAgent(BaseModel):
|
||||
|
||||
return formatted_messages
|
||||
|
||||
def _invoke_loop(self) -> AgentFinish:
|
||||
def _invoke_loop(self, context: Optional[str] = None) -> AgentFinish:
|
||||
"""
|
||||
Run the agent's thought process until it reaches a conclusion or max iterations.
|
||||
|
||||
@@ -462,6 +477,14 @@ class LiteAgent(BaseModel):
|
||||
|
||||
assert isinstance(formatted_answer, AgentFinish)
|
||||
self._show_logs(formatted_answer)
|
||||
|
||||
if self._guardrail:
|
||||
success, feedback = self._guardrail(formatted_answer.output)
|
||||
if not success:
|
||||
return self._invoke_loop(context=feedback)
|
||||
else:
|
||||
return formatted_answer
|
||||
|
||||
return formatted_answer
|
||||
|
||||
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
|
||||
|
||||
Reference in New Issue
Block a user