mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
refactor agent parser
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||
from crewai.agents.parser import (
|
||||
CrewAgentParser,
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
OutputParserException,
|
||||
@@ -95,6 +96,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
else self.stop
|
||||
)
|
||||
)
|
||||
self._parser = CrewAgentParser(agent=self)
|
||||
|
||||
def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||
if "system" in self.prompt:
|
||||
@@ -150,6 +152,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
callbacks=self.callbacks,
|
||||
parser=self._parser,
|
||||
)
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
@@ -161,7 +164,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
printer=self._printer,
|
||||
from_task=self.task
|
||||
)
|
||||
formatted_answer = process_llm_response(answer, self.use_stop_words)
|
||||
formatted_answer = process_llm_response(answer, self.use_stop_words, self._parser)
|
||||
|
||||
if isinstance(formatted_answer, AgentAction):
|
||||
# Extract agent fingerprint if available
|
||||
|
||||
@@ -70,20 +70,6 @@ class CrewAgentParser:
|
||||
def __init__(self, agent: Optional[Any] = None):
|
||||
self.agent = agent
|
||||
|
||||
@staticmethod
|
||||
def parse_text(text: str) -> Union[AgentAction, AgentFinish]:
|
||||
"""
|
||||
Static method to parse text into an AgentAction or AgentFinish without needing to instantiate the class.
|
||||
|
||||
Args:
|
||||
text: The text to parse.
|
||||
|
||||
Returns:
|
||||
Either an AgentAction or AgentFinish based on the parsed content.
|
||||
"""
|
||||
parser = CrewAgentParser()
|
||||
return parser.parse(text)
|
||||
|
||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||
cleaned_text = self._clean_agent_observations(text)
|
||||
|
||||
@@ -163,14 +149,6 @@ class CrewAgentParser:
|
||||
if result in UNABLE_TO_REPAIR_JSON_RESULTS:
|
||||
return tool_input
|
||||
|
||||
# if isinstance(result, str) and result.startswith("[") and result.endswith("]"):
|
||||
# try:
|
||||
# result_data = json.loads(result)
|
||||
# if isinstance(result_data, list) and len(result_data) > 0 and isinstance(result_data[0], dict):
|
||||
# return json.dumps(result_data[0])
|
||||
# except Exception:
|
||||
# ...
|
||||
|
||||
return str(result)
|
||||
|
||||
def _create_agent_action(self, thought: str, action_match: dict, cleaned_text: str) -> AgentAction:
|
||||
|
||||
@@ -35,6 +35,7 @@ from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.agents.parser import (
|
||||
CrewAgentParser,
|
||||
AgentAction,
|
||||
AgentFinish,
|
||||
OutputParserException,
|
||||
@@ -204,6 +205,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
||||
_guardrail_retry_count: int = PrivateAttr(default=0)
|
||||
_parser: CrewAgentParser = PrivateAttr(default_factory=CrewAgentParser)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def setup_llm(self):
|
||||
@@ -239,6 +241,13 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def setup_parser(self):
|
||||
"""Set up the parser after initialization."""
|
||||
self._parser = CrewAgentParser(agent=self.original_agent)
|
||||
|
||||
return self
|
||||
|
||||
@field_validator("guardrail", mode="before")
|
||||
@classmethod
|
||||
def validate_guardrail_function(
|
||||
@@ -511,6 +520,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
messages=self._messages,
|
||||
llm=cast(LLM, self.llm),
|
||||
callbacks=self._callbacks,
|
||||
parser=self._parser,
|
||||
)
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
@@ -553,7 +563,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
)
|
||||
raise e
|
||||
|
||||
formatted_answer = process_llm_response(answer, self.use_stop_words)
|
||||
formatted_answer = process_llm_response(answer, self.use_stop_words, self._parser)
|
||||
|
||||
if isinstance(formatted_answer, AgentAction):
|
||||
try:
|
||||
@@ -622,4 +632,4 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
def _append_message(self, text: str, role: str = "assistant") -> None:
|
||||
"""Append a message to the message list with the given role."""
|
||||
self._messages.append(format_message_for_llm(text, role=role))
|
||||
self._messages.append(format_message_for_llm(text, role=role))
|
||||
|
||||
@@ -71,6 +71,7 @@ def handle_max_iterations_exceeded(
|
||||
messages: List[Dict[str, str]],
|
||||
llm: Union[LLM, BaseLLM],
|
||||
callbacks: List[Any],
|
||||
parser: CrewAgentParser
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""
|
||||
Handles the case when the maximum number of iterations is exceeded.
|
||||
@@ -109,7 +110,7 @@ def handle_max_iterations_exceeded(
|
||||
)
|
||||
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||
|
||||
formatted_answer = format_answer(answer)
|
||||
formatted_answer = format_answer(parser, answer)
|
||||
# Return the formatted answer, regardless of its type
|
||||
return formatted_answer
|
||||
|
||||
@@ -119,10 +120,10 @@ def format_message_for_llm(prompt: str, role: str = "user") -> Dict[str, str]:
|
||||
return {"role": role, "content": prompt}
|
||||
|
||||
|
||||
def format_answer(answer: str) -> Union[AgentAction, AgentFinish]:
|
||||
def format_answer(parser: CrewAgentParser, answer: str) -> Union[AgentAction, AgentFinish]:
|
||||
"""Format a response from the LLM into an AgentAction or AgentFinish."""
|
||||
try:
|
||||
return CrewAgentParser.parse_text(answer)
|
||||
return parser.parse(answer)
|
||||
except Exception:
|
||||
# If parsing fails, return a default AgentFinish
|
||||
return AgentFinish(
|
||||
@@ -173,18 +174,18 @@ def get_llm_response(
|
||||
|
||||
|
||||
def process_llm_response(
|
||||
answer: str, use_stop_words: bool
|
||||
answer: str, use_stop_words: bool, parser: CrewAgentParser
|
||||
) -> Union[AgentAction, AgentFinish]:
|
||||
"""Process the LLM response and format it into an AgentAction or AgentFinish."""
|
||||
if not use_stop_words:
|
||||
try:
|
||||
# Preliminary parsing to check for errors.
|
||||
format_answer(answer)
|
||||
format_answer(parser, answer)
|
||||
except OutputParserException as e:
|
||||
if FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE in e.error:
|
||||
answer = answer.split("Observation:")[0].strip()
|
||||
|
||||
return format_answer(answer)
|
||||
return format_answer(parser, answer)
|
||||
|
||||
|
||||
def handle_agent_action_core(
|
||||
|
||||
Reference in New Issue
Block a user