diff --git a/tests/crew_test.py b/tests/crew_test.py index 8a7fc193a..5f31061ec 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -1779,26 +1779,22 @@ def test_crew_train_success( ] ) - crew_training_handler.assert_has_calls( - [ - mock.call("training_data.pkl"), - mock.call().load(), - mock.call("trained_agents_data.pkl"), - mock.call().save_trained_data( - agent_id="Researcher", - trained_data=task_evaluator().evaluate_training_data().model_dump(), - ), - mock.call("trained_agents_data.pkl"), - mock.call().save_trained_data( - agent_id="Senior Writer", - trained_data=task_evaluator().evaluate_training_data().model_dump(), - ), - mock.call(), - mock.call().load(), - mock.call(), - mock.call().load(), - ] - ) + crew_training_handler.assert_any_call("training_data.pkl") + crew_training_handler().load.assert_called() + + crew_training_handler.assert_any_call("trained_agents_data.pkl") + crew_training_handler().load.assert_called() + + crew_training_handler().save_trained_data.assert_has_calls([ + mock.call( + agent_id="Researcher", + trained_data=task_evaluator().evaluate_training_data().model_dump(), + ), + mock.call( + agent_id="Senior Writer", + trained_data=task_evaluator().evaluate_training_data().model_dump(), + ) + ]) def test_crew_train_error():