mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
fix failing tests
This commit is contained in:
@@ -492,21 +492,29 @@ class Crew(BaseModel):
|
|||||||
train_crew = self.copy()
|
train_crew = self.copy()
|
||||||
train_crew._setup_for_training(filename)
|
train_crew._setup_for_training(filename)
|
||||||
|
|
||||||
for n_iteration in range(n_iterations):
|
try:
|
||||||
train_crew._train_iteration = n_iteration
|
for n_iteration in range(n_iterations):
|
||||||
train_crew.kickoff(inputs=inputs)
|
train_crew._train_iteration = n_iteration
|
||||||
|
train_crew.kickoff(inputs=inputs)
|
||||||
|
|
||||||
training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
|
training_data = CrewTrainingHandler(TRAINING_DATA_FILE).load()
|
||||||
|
|
||||||
for agent in train_crew.agents:
|
for agent in train_crew.agents:
|
||||||
if training_data.get(str(agent.id)):
|
if training_data.get(str(agent.id)):
|
||||||
result = TaskEvaluator(agent).evaluate_training_data(
|
result = TaskEvaluator(agent).evaluate_training_data(
|
||||||
training_data=training_data, agent_id=str(agent.id)
|
training_data=training_data, agent_id=str(agent.id)
|
||||||
)
|
)
|
||||||
|
CrewTrainingHandler(filename).save_trained_data(
|
||||||
CrewTrainingHandler(filename).save_trained_data(
|
agent_id=str(agent.role), trained_data=result.model_dump()
|
||||||
agent_id=str(agent.role), trained_data=result.model_dump()
|
)
|
||||||
)
|
except Exception as e:
|
||||||
|
self._logger.log("error", f"Training failed: {e}", color="red")
|
||||||
|
CrewTrainingHandler(TRAINING_DATA_FILE).clear()
|
||||||
|
CrewTrainingHandler(filename).clear()
|
||||||
|
raise
|
||||||
|
finally:
|
||||||
|
CrewTrainingHandler(TRAINING_DATA_FILE).close()
|
||||||
|
CrewTrainingHandler(filename).close()
|
||||||
|
|
||||||
def kickoff(
|
def kickoff(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -112,7 +112,6 @@ class TaskEvaluator:
|
|||||||
"Cannot proceed with evaluation.\n"
|
"Cannot proceed with evaluation.\n"
|
||||||
"Please check your training implementation."
|
"Please check your training implementation."
|
||||||
)
|
)
|
||||||
self._logger.log("critical", error_msg, color="red")
|
|
||||||
raise ValueError(error_msg)
|
raise ValueError(error_msg)
|
||||||
|
|
||||||
final_aggregated_data += (
|
final_aggregated_data += (
|
||||||
|
|||||||
Reference in New Issue
Block a user