diff --git a/tests/cli/cli_test.py b/tests/cli/cli_test.py index 509b9193a..4f606e213 100644 --- a/tests/cli/cli_test.py +++ b/tests/cli/cli_test.py @@ -15,18 +15,18 @@ def runner(): def test_train_default_iterations(train_crew, runner): result = runner.invoke(train) - train_crew.assert_called_once_with(5) + train_crew.assert_called_once_with(5, "trained_agents_data.pkl") assert result.exit_code == 0 - assert "Training the crew for 5 iterations" in result.output + assert "Training the Crew for 5 iterations" in result.output @mock.patch("crewai.cli.cli.train_crew") def test_train_custom_iterations(train_crew, runner): result = runner.invoke(train, ["--n_iterations", "10"]) - train_crew.assert_called_once_with(10) + train_crew.assert_called_once_with(10, "trained_agents_data.pkl") assert result.exit_code == 0 - assert "Training the crew for 10 iterations" in result.output + assert "Training the Crew for 10 iterations" in result.output @mock.patch("crewai.cli.cli.train_crew")