Fixing training while refactoring code

Brandon Hancock
2025-01-28 11:27:15 -05:00
parent c310044bec
commit 6ccede42f7
2 changed files with 88 additions and 70 deletions


@@ -485,82 +485,99 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return {"role": role, "content": prompt} return {"role": role, "content": prompt}
def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish: def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
""" """Handle human feedback with different flows for training vs regular use.
Handles the human feedback loop, allowing the user to provide feedback
on the agent's output and determining if additional iterations are needed.
Parameters: Args:
formatted_answer (AgentFinish): The initial output from the agent. formatted_answer: The initial AgentFinish result to get feedback on
Returns: Returns:
AgentFinish: The final output after incorporating human feedback. AgentFinish: The final answer after processing feedback
""" """
while self.ask_for_human_input:
human_feedback = self._ask_human_input(formatted_answer.output) human_feedback = self._ask_human_input(formatted_answer.output)
if self.crew and self.crew._train: if self._is_training_mode():
self._handle_crew_training_output(formatted_answer, human_feedback) return self._handle_training_feedback(formatted_answer, human_feedback)
# Make an LLM call to verify if additional changes are requested based on human feedback return self._handle_regular_feedback(formatted_answer, human_feedback)
additional_changes_prompt = self._i18n.slice(
"human_feedback_classification"
).format(feedback=human_feedback)
retry_count = 0 def _is_training_mode(self) -> bool:
llm_call_successful = False """Check if crew is in training mode."""
additional_changes_response = None return bool(self.crew and self.crew._train)
while retry_count < MAX_LLM_RETRY and not llm_call_successful: def _handle_training_feedback(
try: self, initial_answer: AgentFinish, feedback: str
additional_changes_response = ( ) -> AgentFinish:
self.llm.call( """Process feedback for training scenarios with single iteration."""
[ self._handle_crew_training_output(initial_answer, feedback)
self._format_msg( self.messages.append(self._format_msg(f"Feedback: {feedback}"))
additional_changes_prompt, role="system" improved_answer = self._invoke_loop()
) self._handle_crew_training_output(improved_answer)
], self.ask_for_human_input = False # Ensure single iteration
callbacks=self.callbacks, return improved_answer
)
.strip()
.lower()
)
llm_call_successful = True
except Exception as e:
retry_count += 1
self._printer.print( def _handle_regular_feedback(
content=f"Error during LLM call to classify human feedback: {e}. Retrying... ({retry_count}/{MAX_LLM_RETRY})", self, current_answer: AgentFinish, initial_feedback: str
color="red", ) -> AgentFinish:
) """Process feedback for regular use with potential multiple iterations."""
feedback = initial_feedback
answer = current_answer
if not llm_call_successful: while self.ask_for_human_input:
self._printer.print( response = self._get_llm_feedback_response(feedback)
content="Error processing feedback after multiple attempts.",
color="red", if not self._feedback_requires_changes(response):
)
self.ask_for_human_input = False self.ask_for_human_input = False
break
if additional_changes_response == "false":
self.ask_for_human_input = False
elif additional_changes_response == "true":
self.ask_for_human_input = True
# Add human feedback to messages
self.messages.append(self._format_msg(f"Feedback: {human_feedback}"))
# Invoke the loop again with updated messages
formatted_answer = self._invoke_loop()
if self.crew and self.crew._train:
self._handle_crew_training_output(formatted_answer)
else: else:
# Unexpected response answer = self._process_feedback_iteration(feedback)
feedback = self._ask_human_input(answer.output)
return answer
def _get_llm_feedback_response(self, feedback: str) -> Optional[str]:
"""Get LLM classification of whether feedback requires changes."""
prompt = self._i18n.slice("human_feedback_classification").format(
feedback=feedback
)
message = self._format_msg(prompt, role="system")
for retry in range(MAX_LLM_RETRY):
try:
response = self.llm.call([message], callbacks=self.callbacks)
return response.strip().lower() if response else None
except Exception as error:
self._log_feedback_error(retry, error)
self._log_max_retries_exceeded()
return None
def _feedback_requires_changes(self, response: Optional[str]) -> bool:
"""Determine if feedback response indicates need for changes."""
return response == "true" if response else False
def _process_feedback_iteration(self, feedback: str) -> AgentFinish:
"""Process a single feedback iteration."""
self.messages.append(self._format_msg(f"Feedback: {feedback}"))
return self._invoke_loop()
def _log_feedback_error(self, retry_count: int, error: Exception) -> None:
"""Log feedback processing errors."""
self._printer.print( self._printer.print(
content=f"Unexpected response from LLM: '{additional_changes_response}'. Assuming no additional changes requested.", content=(
f"Error processing feedback: {error}. "
f"Retrying... ({retry_count + 1}/{MAX_LLM_RETRY})"
),
color="red", color="red",
) )
self.ask_for_human_input = False
return formatted_answer def _log_max_retries_exceeded(self) -> None:
"""Log when max retries for feedback processing are exceeded."""
self._printer.print(
content=(
f"Failed to process feedback after {MAX_LLM_RETRY} attempts. "
"Ending feedback loop."
),
color="red",
)
def _handle_max_iterations_exceeded(self, formatted_answer): def _handle_max_iterations_exceeded(self, formatted_answer):
""" """


@@ -90,15 +90,16 @@ class TaskEvaluator:
             - training_data (dict): The training data to be evaluated.
             - agent_id (str): The ID of the agent.
         """
+        print("Training data: ", training_data)
         output_training_data = training_data[agent_id]
 
         final_aggregated_data = ""
         for _, data in output_training_data.items():
             final_aggregated_data += (
-                f"Initial Output:\n{data.get('initial_output', '')}\n\n"
-                f"Human Feedback:\n{data.get('human_feedback', '')}\n\n"
-                f"Improved Output:\n{data.get('improved_output', '')}\n\n"
+                f"Initial Output:\n{data.get('initial_output')}\n\n"
+                f"Human Feedback:\n{data.get('human_feedback')}\n\n"
+                f"Improved Output:\n{data.get('improved_output')}\n\n"
             )
 
         evaluation_query = (
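For context, the loop in this hunk flattens each agent's recorded feedback rounds into a single string that feeds the evaluation query. A small illustrative sketch follows; the agent id and iteration keys are invented for the example, and only the initial_output / human_feedback / improved_output keys come from the diff.

# Hypothetical shape of the per-agent training data that the loop aggregates;
# the agent id and iteration keys are made up, the inner keys mirror the
# data.get(...) calls in the diff.
training_data = {
    "agent-1": {
        0: {
            "initial_output": "Draft answer",
            "human_feedback": "Cite sources",
            "improved_output": "Draft answer with citations",
        },
    },
}

final_aggregated_data = ""
for _, data in training_data["agent-1"].items():
    final_aggregated_data += (
        f"Initial Output:\n{data.get('initial_output')}\n\n"
        f"Human Feedback:\n{data.get('human_feedback')}\n\n"
        f"Improved Output:\n{data.get('improved_output')}\n\n"
    )
print(final_aggregated_data)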