From 25690c861f6fc99f09684e7d4842f01855ca5c8a Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Wed, 29 Jan 2025 14:16:00 -0500 Subject: [PATCH] simplify training --- src/crewai/agents/crew_agent_executor.py | 74 +++++++++++------------- 1 file changed, 35 insertions(+), 39 deletions(-) diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py index 20ca4c0d1..b144872b1 100644 --- a/src/crewai/agents/crew_agent_executor.py +++ b/src/crewai/agents/crew_agent_executor.py @@ -432,58 +432,50 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): ) def _handle_crew_training_output( - self, result: AgentFinish, human_feedback: str | None = None + self, result: AgentFinish, human_feedback: Optional[str] = None ) -> None: - """Function to handle the process of the training data.""" + """Handle the process of saving training data.""" agent_id = str(self.agent.id) # type: ignore + train_iteration = ( + getattr(self.crew, "_train_iteration", None) if self.crew else None + ) + + if train_iteration is None or not isinstance(train_iteration, int): + self._printer.print( + content="Invalid or missing train iteration. Cannot save training data.", + color="red", + ) + return - # Load training data training_handler = CrewTrainingHandler(TRAINING_DATA_FILE) - training_data = training_handler.load() + training_data = training_handler.load() or {} - # Check if training data exists, human input is not requested, and self.crew is valid - if training_data and not self.ask_for_human_input: - if self.crew is not None and hasattr(self.crew, "_train_iteration"): - train_iteration = self.crew._train_iteration - if agent_id in training_data and isinstance(train_iteration, int): - training_data[agent_id][train_iteration][ - "improved_output" - ] = result.output - training_handler.save(training_data) - else: - self._printer.print( - content="Invalid train iteration type or agent_id not in training data.", - color="red", - ) - else: - self._printer.print( - content="Crew is None or does not have _train_iteration attribute.", - color="red", - ) + # Initialize or retrieve agent's training data + agent_training_data = training_data.get(agent_id, {}) - if self.ask_for_human_input and human_feedback is not None: - training_data = { + if human_feedback is not None: + # Save initial output and human feedback + agent_training_data[train_iteration] = { "initial_output": result.output, "human_feedback": human_feedback, - "agent": agent_id, - "agent_role": self.agent.role, # type: ignore } - if self.crew is not None and hasattr(self.crew, "_train_iteration"): - train_iteration = self.crew._train_iteration - if isinstance(train_iteration, int): - CrewTrainingHandler(TRAINING_DATA_FILE).append( - train_iteration, agent_id, training_data - ) - else: - self._printer.print( - content="Invalid train iteration type. Expected int.", - color="red", - ) + else: + # Save improved output + if train_iteration in agent_training_data: + agent_training_data[train_iteration]["improved_output"] = result.output else: self._printer.print( - content="Crew is None or does not have _train_iteration attribute.", + content=( + f"No existing training data for agent {agent_id} and iteration " + f"{train_iteration}. Cannot save improved output." + ), color="red", ) + return + + # Update the training data and save + training_data[agent_id] = agent_training_data + training_handler.save(training_data) def _format_prompt(self, prompt: str, inputs: Dict[str, str]) -> str: prompt = prompt.replace("{input}", inputs["input"]) @@ -522,6 +514,10 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): self, initial_answer: AgentFinish, feedback: str ) -> AgentFinish: """Process feedback for training scenarios with single iteration.""" + self._printer.print( + content="\nProcessing training feedback.\n", + color="yellow", + ) self._handle_crew_training_output(initial_answer, feedback) self.messages.append(self._format_msg(f"Feedback: {feedback}")) improved_answer = self._invoke_loop()