mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 16:48:30 +00:00
kickoff to support guardrail for an LiteAgent
This commit is contained in:
@@ -154,6 +154,10 @@ class LiteAgent(BaseModel):
|
|||||||
original_agent: Optional[BaseAgent] = Field(
|
original_agent: Optional[BaseAgent] = Field(
|
||||||
default=None, description="Reference to the agent that created this LiteAgent"
|
default=None, description="Reference to the agent that created this LiteAgent"
|
||||||
)
|
)
|
||||||
|
guardrail: Optional[str] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Description of a guardrail to validate the agent's output",
|
||||||
|
)
|
||||||
# Private Attributes
|
# Private Attributes
|
||||||
_parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
_parsed_tools: List[CrewStructuredTool] = PrivateAttr(default_factory=list)
|
||||||
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
_token_process: TokenProcess = PrivateAttr(default_factory=TokenProcess)
|
||||||
@@ -162,7 +166,8 @@ class LiteAgent(BaseModel):
|
|||||||
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
|
_messages: List[Dict[str, str]] = PrivateAttr(default_factory=list)
|
||||||
_iterations: int = PrivateAttr(default=0)
|
_iterations: int = PrivateAttr(default=0)
|
||||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||||
|
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def setup_llm(self):
|
def setup_llm(self):
|
||||||
"""Set up the LLM and other components after initialization."""
|
"""Set up the LLM and other components after initialization."""
|
||||||
@@ -183,6 +188,16 @@ class LiteAgent(BaseModel):
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def init_guardrail(self) -> "LiteAgent":
|
||||||
|
if isinstance(self.guardrail, str):
|
||||||
|
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||||
|
|
||||||
|
assert self.llm is LLM
|
||||||
|
self._guardrail = LLMGuardrail(description=self.guardrail, llm=self.llm)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def key(self) -> str:
|
def key(self) -> str:
|
||||||
"""Get the unique key for this agent instance."""
|
"""Get the unique key for this agent instance."""
|
||||||
@@ -351,7 +366,7 @@ class LiteAgent(BaseModel):
|
|||||||
|
|
||||||
return formatted_messages
|
return formatted_messages
|
||||||
|
|
||||||
def _invoke_loop(self) -> AgentFinish:
|
def _invoke_loop(self, context: Optional[str] = None) -> AgentFinish:
|
||||||
"""
|
"""
|
||||||
Run the agent's thought process until it reaches a conclusion or max iterations.
|
Run the agent's thought process until it reaches a conclusion or max iterations.
|
||||||
|
|
||||||
@@ -462,6 +477,14 @@ class LiteAgent(BaseModel):
|
|||||||
|
|
||||||
assert isinstance(formatted_answer, AgentFinish)
|
assert isinstance(formatted_answer, AgentFinish)
|
||||||
self._show_logs(formatted_answer)
|
self._show_logs(formatted_answer)
|
||||||
|
|
||||||
|
if self._guardrail:
|
||||||
|
success, feedback = self._guardrail(formatted_answer.output)
|
||||||
|
if not success:
|
||||||
|
return self._invoke_loop(context=feedback)
|
||||||
|
else:
|
||||||
|
return formatted_answer
|
||||||
|
|
||||||
return formatted_answer
|
return formatted_answer
|
||||||
|
|
||||||
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
|
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
|
||||||
|
|||||||
Reference in New Issue
Block a user