feat: improve action detection when agent provide multiples choices

This commit is contained in:
Lucas Gomide
2025-07-17 15:50:07 -03:00
parent c212dc2155
commit 63f7d75b34

View File

@@ -89,6 +89,8 @@ class CrewAgentParser:
thought = self._extract_thought(text)
includes_answer = FINAL_ANSWER_ACTION in text
action_match = self._find_last_action_input_pair(cleaned_text)
final_answer = cleaned_text.split(FINAL_ANSWER_ACTION)[-1].strip()
# Check whether the final answer ends with triple backticks.
if final_answer.endswith("```"):
@@ -100,15 +102,7 @@ class CrewAgentParser:
return AgentFinish(thought, final_answer, text)
elif action_match:
action = action_match.group(1)
clean_action = self._clean_action(action)
action_input = action_match.group(2).strip()
tool_input = action_input.strip(" ").strip('"')
safe_tool_input = self._safe_repair_json(tool_input)
return AgentAction(thought, clean_action, safe_tool_input, cleaned_text)
return self._create_agent_action(thought, action_match, cleaned_text)
if not re.search(r"Action\s*\d*\s*:[\s]*(.*?)", cleaned_text, re.DOTALL):
raise OutputParserException(
@@ -173,6 +167,54 @@ class CrewAgentParser:
return str(result)
def _create_agent_action(self, thought: str, action_match: dict, cleaned_text: str) -> AgentAction:
action = action_match["action"]
clean_action = self._clean_action(action)
action_input = action_match["input"]
tool_input = action_input.strip(" ").strip('"')
safe_tool_input = self._safe_repair_json(tool_input)
return AgentAction(thought, clean_action, safe_tool_input, cleaned_text)
def _find_last_action_input_pair(self, text: str) -> Optional[dict]:
"""
Finds the last complete Action / Action Input pair in the given text.
Useful for handling multiple action/observation cycles.
"""
def _match_all_pairs(text: str) -> list[tuple[str, str]]:
pattern = (
r"Action\s*\d*\s*:\s*([^\n]+)" # Action content
r"\s*[\n]+" # Optional whitespace/newline
r"Action\s*\d*\s*Input\s*\d*\s*:\s*" # Action Input label
r"([^\n]*(?:\n(?!Observation:|Thought:|Action\s*\d*\s*:|Final Answer:)[^\n]*)*)"
)
return re.findall(pattern, text, re.MULTILINE | re.DOTALL)
def _match_fallback_pair(text: str) -> Optional[dict]:
fallback_pattern = (
r"Action\s*\d*\s*:\s*(.*?)"
r"\s*Action\s*\d*\s*Input\s*\d*\s*:\s*"
r"(.*?)(?=\n(?:Observation:|Thought:|Action\s*\d*\s*:|Final Answer:)|$)"
)
match = re.search(fallback_pattern, text, re.DOTALL)
if match:
return {
"action": match.group(1).strip(),
"input": match.group(2).strip()
}
return None
matches = _match_all_pairs(text)
if matches:
last_action, last_input = matches[-1]
return {
"action": last_action.strip(),
"input": last_input.strip()
}
return _match_fallback_pair(text)
def _clean_agent_observations(self, text: str) -> str:
"""
Remove agent-written observations from the text.