Fixing training while refactoring code

Brandon Hancock
2025-01-28 11:27:15 -05:00
parent c310044bec
commit 6ccede42f7
2 changed files with 88 additions and 70 deletions


@@ -485,82 +485,99 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
return {"role": role, "content": prompt}
def _handle_human_feedback(self, formatted_answer: AgentFinish) -> AgentFinish:
"""
Handles the human feedback loop, allowing the user to provide feedback
on the agent's output and determining if additional iterations are needed.
"""Handle human feedback with different flows for training vs regular use.
Parameters:
formatted_answer (AgentFinish): The initial output from the agent.
Args:
formatted_answer: The initial AgentFinish result to get feedback on
Returns:
AgentFinish: The final output after incorporating human feedback.
AgentFinish: The final answer after processing feedback
"""
human_feedback = self._ask_human_input(formatted_answer.output)
if self._is_training_mode():
return self._handle_training_feedback(formatted_answer, human_feedback)
return self._handle_regular_feedback(formatted_answer, human_feedback)
def _is_training_mode(self) -> bool:
"""Check if crew is in training mode."""
return bool(self.crew and self.crew._train)
def _handle_training_feedback(
self, initial_answer: AgentFinish, feedback: str
) -> AgentFinish:
"""Process feedback for training scenarios with single iteration."""
self._handle_crew_training_output(initial_answer, feedback)
self.messages.append(self._format_msg(f"Feedback: {feedback}"))
improved_answer = self._invoke_loop()
self._handle_crew_training_output(improved_answer)
self.ask_for_human_input = False # Ensure single iteration
return improved_answer
def _handle_regular_feedback(
self, current_answer: AgentFinish, initial_feedback: str
) -> AgentFinish:
"""Process feedback for regular use with potential multiple iterations."""
feedback = initial_feedback
answer = current_answer
while self.ask_for_human_input:
human_feedback = self._ask_human_input(formatted_answer.output)
response = self._get_llm_feedback_response(feedback)
if self.crew and self.crew._train:
self._handle_crew_training_output(formatted_answer, human_feedback)
# Make an LLM call to verify if additional changes are requested based on human feedback
additional_changes_prompt = self._i18n.slice(
"human_feedback_classification"
).format(feedback=human_feedback)
retry_count = 0
llm_call_successful = False
additional_changes_response = None
while retry_count < MAX_LLM_RETRY and not llm_call_successful:
try:
additional_changes_response = (
self.llm.call(
[
self._format_msg(
additional_changes_prompt, role="system"
)
],
callbacks=self.callbacks,
)
.strip()
.lower()
)
llm_call_successful = True
except Exception as e:
retry_count += 1
self._printer.print(
content=f"Error during LLM call to classify human feedback: {e}. Retrying... ({retry_count}/{MAX_LLM_RETRY})",
color="red",
)
if not llm_call_successful:
self._printer.print(
content="Error processing feedback after multiple attempts.",
color="red",
)
if not self._feedback_requires_changes(response):
self.ask_for_human_input = False
break
if additional_changes_response == "false":
self.ask_for_human_input = False
elif additional_changes_response == "true":
self.ask_for_human_input = True
# Add human feedback to messages
self.messages.append(self._format_msg(f"Feedback: {human_feedback}"))
# Invoke the loop again with updated messages
formatted_answer = self._invoke_loop()
if self.crew and self.crew._train:
self._handle_crew_training_output(formatted_answer)
else:
# Unexpected response
self._printer.print(
content=f"Unexpected response from LLM: '{additional_changes_response}'. Assuming no additional changes requested.",
color="red",
)
self.ask_for_human_input = False
answer = self._process_feedback_iteration(feedback)
feedback = self._ask_human_input(answer.output)
return formatted_answer
return answer
def _get_llm_feedback_response(self, feedback: str) -> Optional[str]:
"""Get LLM classification of whether feedback requires changes."""
prompt = self._i18n.slice("human_feedback_classification").format(
feedback=feedback
)
message = self._format_msg(prompt, role="system")
for retry in range(MAX_LLM_RETRY):
try:
response = self.llm.call([message], callbacks=self.callbacks)
return response.strip().lower() if response else None
except Exception as error:
self._log_feedback_error(retry, error)
self._log_max_retries_exceeded()
return None
def _feedback_requires_changes(self, response: Optional[str]) -> bool:
"""Determine if feedback response indicates need for changes."""
return response == "true" if response else False
def _process_feedback_iteration(self, feedback: str) -> AgentFinish:
"""Process a single feedback iteration."""
self.messages.append(self._format_msg(f"Feedback: {feedback}"))
return self._invoke_loop()
def _log_feedback_error(self, retry_count: int, error: Exception) -> None:
"""Log feedback processing errors."""
self._printer.print(
content=(
f"Error processing feedback: {error}. "
f"Retrying... ({retry_count + 1}/{MAX_LLM_RETRY})"
),
color="red",
)
def _log_max_retries_exceeded(self) -> None:
"""Log when max retries for feedback processing are exceeded."""
self._printer.print(
content=(
f"Failed to process feedback after {MAX_LLM_RETRY} attempts. "
"Ending feedback loop."
),
color="red",
)
def _handle_max_iterations_exceeded(self, formatted_answer):
"""


@@ -90,15 +90,16 @@ class TaskEvaluator:
         - training_data (dict): The training data to be evaluated.
         - agent_id (str): The ID of the agent.
         """
+        print("Training data: ", training_data)
         output_training_data = training_data[agent_id]
         final_aggregated_data = ""
 
         for _, data in output_training_data.items():
             final_aggregated_data += (
-                f"Initial Output:\n{data.get('initial_output', '')}\n\n"
-                f"Human Feedback:\n{data.get('human_feedback', '')}\n\n"
-                f"Improved Output:\n{data.get('improved_output', '')}\n\n"
+                f"Initial Output:\n{data.get('initial_output')}\n\n"
+                f"Human Feedback:\n{data.get('human_feedback')}\n\n"
+                f"Improved Output:\n{data.get('improved_output')}\n\n"
             )
 
         evaluation_query = (
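
The aggregation loop above can be exercised on its own. A minimal sketch, assuming training data keyed first by agent ID and then by iteration index (the shape is inferred from the loop, not spelled out in this diff); it also shows the visible effect of dropping the `''` default from `data.get(...)`.

    # Assumed shape, inferred from the loop above (not confirmed by this diff):
    # {agent_id: {iteration: {"initial_output": ..., "human_feedback": ...,
    #                         "improved_output": ...}}}
    training_data = {
        "agent-1": {
            0: {
                "initial_output": "First draft.",
                "human_feedback": "Make it shorter.",
                # "improved_output" deliberately missing for illustration
            },
        },
    }

    final_aggregated_data = ""
    for _, data in training_data["agent-1"].items():
        final_aggregated_data += (
            f"Initial Output:\n{data.get('initial_output')}\n\n"
            f"Human Feedback:\n{data.get('human_feedback')}\n\n"
            f"Improved Output:\n{data.get('improved_output')}\n\n"
        )

    print(final_aggregated_data)
    # Without the ", ''" default a missing key renders as the literal string
    # "None" in the aggregated text, rather than as an empty field.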