mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
refactor agent parser
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Any, Callable, Dict, List, Optional, Union
|
|||||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||||
from crewai.agents.parser import (
|
from crewai.agents.parser import (
|
||||||
|
CrewAgentParser,
|
||||||
AgentAction,
|
AgentAction,
|
||||||
AgentFinish,
|
AgentFinish,
|
||||||
OutputParserException,
|
OutputParserException,
|
||||||
@@ -95,6 +96,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
else self.stop
|
else self.stop
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
self._parser = CrewAgentParser(agent=self)
|
||||||
|
|
||||||
def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
def invoke(self, inputs: Dict[str, str]) -> Dict[str, Any]:
|
||||||
if "system" in self.prompt:
|
if "system" in self.prompt:
|
||||||
@@ -150,6 +152,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
messages=self.messages,
|
messages=self.messages,
|
||||||
llm=self.llm,
|
llm=self.llm,
|
||||||
callbacks=self.callbacks,
|
callbacks=self.callbacks,
|
||||||
|
parser=self._parser,
|
||||||
)
|
)
|
||||||
|
|
||||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||||
@@ -161,7 +164,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
printer=self._printer,
|
printer=self._printer,
|
||||||
from_task=self.task
|
from_task=self.task
|
||||||
)
|
)
|
||||||
formatted_answer = process_llm_response(answer, self.use_stop_words)
|
formatted_answer = process_llm_response(answer, self.use_stop_words, self._parser)
|
||||||
|
|
||||||
if isinstance(formatted_answer, AgentAction):
|
if isinstance(formatted_answer, AgentAction):
|
||||||
# Extract agent fingerprint if available
|
# Extract agent fingerprint if available
|
||||||
|
|||||||
@@ -70,20 +70,6 @@ class CrewAgentParser:
|
|||||||
def __init__(self, agent: Optional[Any] = None):
|
def __init__(self, agent: Optional[Any] = None):
|
||||||
self.agent = agent
|
self.agent = agent
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def parse_text(text: str) -> Union[AgentAction, AgentFinish]:
|
|
||||||
"""
|
|
||||||
Static method to parse text into an AgentAction or AgentFinish without needing to instantiate the class.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
text: The text to parse.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Either an AgentAction or AgentFinish based on the parsed content.
|
|
||||||
"""
|
|
||||||
parser = CrewAgentParser()
|
|
||||||
return parser.parse(text)
|
|
||||||
|
|
||||||
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
|
||||||
cleaned_text = self._clean_agent_observations(text)
|
cleaned_text = self._clean_agent_observations(text)
|
||||||
|
|
||||||
@@ -163,14 +149,6 @@ class CrewAgentParser:
|
|||||||
if result in UNABLE_TO_REPAIR_JSON_RESULTS:
|
if result in UNABLE_TO_REPAIR_JSON_RESULTS:
|
||||||
return tool_input
|
return tool_input
|
||||||
|
|
||||||
# if isinstance(result, str) and result.startswith("[") and result.endswith("]"):
|
|
||||||
# try:
|
|
||||||
# result_data = json.loads(result)
|
|
||||||
# if isinstance(result_data, list) and len(result_data) > 0 and isinstance(result_data[0], dict):
|
|
||||||
# return json.dumps(result_data[0])
|
|
||||||
# except Exception:
|
|
||||||
# ...
|
|
||||||
|
|
||||||
return str(result)
|
return str(result)
|
||||||
|
|
||||||
def _create_agent_action(self, thought: str, action_match: dict, cleaned_text: str) -> AgentAction:
|
def _create_agent_action(self, thought: str, action_match: dict, cleaned_text: str) -> AgentAction:
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ from crewai.agents.agent_builder.base_agent import BaseAgent
|
|||||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||||
from crewai.agents.cache import CacheHandler
|
from crewai.agents.cache import CacheHandler
|
||||||
from crewai.agents.parser import (
|
from crewai.agents.parser import (
|
||||||
|
CrewAgentParser,
|
||||||
AgentAction,
|
AgentAction,
|
||||||
AgentFinish,
|
AgentFinish,
|
||||||
OutputParserException,
|
OutputParserException,
|
||||||
@@ -204,6 +205,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||||
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
_guardrail: Optional[Callable] = PrivateAttr(default=None)
|
||||||
_guardrail_retry_count: int = PrivateAttr(default=0)
|
_guardrail_retry_count: int = PrivateAttr(default=0)
|
||||||
|
_parser: CrewAgentParser = PrivateAttr(default_factory=CrewAgentParser)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def setup_llm(self):
|
def setup_llm(self):
|
||||||
@@ -239,6 +241,13 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def setup_parser(self):
|
||||||
|
"""Set up the parser after initialization."""
|
||||||
|
self._parser = CrewAgentParser(agent=self.original_agent)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
@field_validator("guardrail", mode="before")
|
@field_validator("guardrail", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_guardrail_function(
|
def validate_guardrail_function(
|
||||||
@@ -511,6 +520,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
messages=self._messages,
|
messages=self._messages,
|
||||||
llm=cast(LLM, self.llm),
|
llm=cast(LLM, self.llm),
|
||||||
callbacks=self._callbacks,
|
callbacks=self._callbacks,
|
||||||
|
parser=self._parser,
|
||||||
)
|
)
|
||||||
|
|
||||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||||
@@ -553,7 +563,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
|||||||
)
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
formatted_answer = process_llm_response(answer, self.use_stop_words)
|
formatted_answer = process_llm_response(answer, self.use_stop_words, self._parser)
|
||||||
|
|
||||||
if isinstance(formatted_answer, AgentAction):
|
if isinstance(formatted_answer, AgentAction):
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -71,6 +71,7 @@ def handle_max_iterations_exceeded(
|
|||||||
messages: List[Dict[str, str]],
|
messages: List[Dict[str, str]],
|
||||||
llm: Union[LLM, BaseLLM],
|
llm: Union[LLM, BaseLLM],
|
||||||
callbacks: List[Any],
|
callbacks: List[Any],
|
||||||
|
parser: CrewAgentParser
|
||||||
) -> Union[AgentAction, AgentFinish]:
|
) -> Union[AgentAction, AgentFinish]:
|
||||||
"""
|
"""
|
||||||
Handles the case when the maximum number of iterations is exceeded.
|
Handles the case when the maximum number of iterations is exceeded.
|
||||||
@@ -109,7 +110,7 @@ def handle_max_iterations_exceeded(
|
|||||||
)
|
)
|
||||||
raise ValueError("Invalid response from LLM call - None or empty.")
|
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||||
|
|
||||||
formatted_answer = format_answer(answer)
|
formatted_answer = format_answer(parser, answer)
|
||||||
# Return the formatted answer, regardless of its type
|
# Return the formatted answer, regardless of its type
|
||||||
return formatted_answer
|
return formatted_answer
|
||||||
|
|
||||||
@@ -119,10 +120,10 @@ def format_message_for_llm(prompt: str, role: str = "user") -> Dict[str, str]:
|
|||||||
return {"role": role, "content": prompt}
|
return {"role": role, "content": prompt}
|
||||||
|
|
||||||
|
|
||||||
def format_answer(answer: str) -> Union[AgentAction, AgentFinish]:
|
def format_answer(parser: CrewAgentParser, answer: str) -> Union[AgentAction, AgentFinish]:
|
||||||
"""Format a response from the LLM into an AgentAction or AgentFinish."""
|
"""Format a response from the LLM into an AgentAction or AgentFinish."""
|
||||||
try:
|
try:
|
||||||
return CrewAgentParser.parse_text(answer)
|
return parser.parse(answer)
|
||||||
except Exception:
|
except Exception:
|
||||||
# If parsing fails, return a default AgentFinish
|
# If parsing fails, return a default AgentFinish
|
||||||
return AgentFinish(
|
return AgentFinish(
|
||||||
@@ -173,18 +174,18 @@ def get_llm_response(
|
|||||||
|
|
||||||
|
|
||||||
def process_llm_response(
|
def process_llm_response(
|
||||||
answer: str, use_stop_words: bool
|
answer: str, use_stop_words: bool, parser: CrewAgentParser
|
||||||
) -> Union[AgentAction, AgentFinish]:
|
) -> Union[AgentAction, AgentFinish]:
|
||||||
"""Process the LLM response and format it into an AgentAction or AgentFinish."""
|
"""Process the LLM response and format it into an AgentAction or AgentFinish."""
|
||||||
if not use_stop_words:
|
if not use_stop_words:
|
||||||
try:
|
try:
|
||||||
# Preliminary parsing to check for errors.
|
# Preliminary parsing to check for errors.
|
||||||
format_answer(answer)
|
format_answer(parser, answer)
|
||||||
except OutputParserException as e:
|
except OutputParserException as e:
|
||||||
if FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE in e.error:
|
if FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE in e.error:
|
||||||
answer = answer.split("Observation:")[0].strip()
|
answer = answer.split("Observation:")[0].strip()
|
||||||
|
|
||||||
return format_answer(answer)
|
return format_answer(parser, answer)
|
||||||
|
|
||||||
|
|
||||||
def handle_agent_action_core(
|
def handle_agent_action_core(
|
||||||
|
|||||||
Reference in New Issue
Block a user