From 2ec68501aefb77d331a1eda1e70bca0749cdf74f Mon Sep 17 00:00:00 2001 From: Eduardo Chiarotti Date: Fri, 11 Oct 2024 19:35:35 -0300 Subject: [PATCH] fix: training issue --- src/crewai/agents/crew_agent_executor.py | 26 ++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py index 3cb195206..d15c80732 100644 --- a/src/crewai/agents/crew_agent_executor.py +++ b/src/crewai/agents/crew_agent_executor.py @@ -334,6 +334,32 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): color="red", ) + if self.ask_for_human_input and human_feedback is not None: + training_data = { + "initial_output": result.output, + "human_feedback": human_feedback, + "agent": agent_id, + "agent_role": self.agent.role, + } + if self.crew is not None and hasattr(self.crew, "_train_iteration"): + train_iteration = self.crew._train_iteration + if isinstance(train_iteration, int): + CrewTrainingHandler(TRAINING_DATA_FILE).append( + train_iteration, agent_id, training_data + ) + else: + self._logger.log( + "error", + "Invalid train iteration type. Expected int.", + color="red", + ) + else: + self._logger.log( + "error", + "Crew is None or does not have _train_iteration attribute.", + color="red", + ) + def _format_prompt(self, prompt: str, inputs: Dict[str, str]) -> str: prompt = prompt.replace("{input}", inputs["input"]) prompt = prompt.replace("{tool_names}", inputs["tool_names"])