From 10b42bd0d49d6ef04a039ff3ed55fa4ad4e52a5d Mon Sep 17 00:00:00 2001
From: Brandon Hancock
Date: Tue, 28 Jan 2025 14:25:10 -0500
Subject: [PATCH] fix failing tests

---
 src/crewai/crew.py                                 | 34 ++++++++++++-------
 .../utilities/evaluators/task_evaluator.py         |  1 -
 2 files changed, 21 insertions(+), 14 deletions(-)

diff --git a/src/crewai/crew.py b/src/crewai/crew.py
index 5d4b9ff79..96480ed04 100644
--- a/src/crewai/crew.py
+++ b/src/crewai/crew.py
@@ -492,21 +492,29 @@ class Crew(BaseModel):
         train_crew = self.copy()
         train_crew._setup_for_training(filename)
 
-        for n_iteration in range(n_iterations):
-            train_crew._train_iteration = n_iteration
-            train_crew.kickoff(inputs=inputs)
+        try:
+            for n_iteration in range(n_iterations):
+                train_crew._train_iteration = n_iteration
+                train_crew.kickoff(inputs=inputs)
 
-        training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
+            training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
 
-        for agent in train_crew.agents:
-            if training_data.get(str(agent.id)):
-                result = TaskEvaluator(agent).evaluate_training_data(
-                    training_data=training_data, agent_id=str(agent.id)
-                )
-
-                CrewTrainingHandler(filename).save_trained_data(
-                    agent_id=str(agent.role), trained_data=result.model_dump()
-                )
+            for agent in train_crew.agents:
+                if training_data.get(str(agent.id)):
+                    result = TaskEvaluator(agent).evaluate_training_data(
+                        training_data=training_data, agent_id=str(agent.id)
+                    )
+                    CrewTrainingHandler(filename).save_trained_data(
+                        agent_id=str(agent.role), trained_data=result.model_dump()
+                    )
+        except Exception as e:
+            self._logger.log("error", f"Training failed: {e}", color="red")
+            CrewTrainingHandler(TRAINING_DATA_FILE).clear()
+            CrewTrainingHandler(filename).clear()
+            raise
+        finally:
+            CrewTrainingHandler(TRAINING_DATA_FILE).close()
+            CrewTrainingHandler(filename).close()
 
     def kickoff(
         self,
diff --git a/src/crewai/utilities/evaluators/task_evaluator.py b/src/crewai/utilities/evaluators/task_evaluator.py
index 0b7b8cca7..294629274 100644
--- a/src/crewai/utilities/evaluators/task_evaluator.py
+++ b/src/crewai/utilities/evaluators/task_evaluator.py
@@ -112,7 +112,6 @@ class TaskEvaluator:
                 "Cannot proceed with evaluation.\n"
                 "Please check your training implementation."
             )
-            self._logger.log("critical", error_msg, color="red")
             raise ValueError(error_msg)
 
         final_aggregated_data += (