diff --git a/src/crewai/cli/train_crew.py b/src/crewai/cli/train_crew.py index a22647828..12c5191b1 100644 --- a/src/crewai/cli/train_crew.py +++ b/src/crewai/cli/train_crew.py @@ -16,7 +16,7 @@ def train_crew(n_iterations: int, filename: str) -> None: if n_iterations <= 0: raise ValueError("The number of iterations must be a positive integer.") - if filename.endswith(".pkl"): + if not filename.endswith(".pkl"): raise ValueError("The filename must not end with .pkl") result = subprocess.run(command, capture_output=False, text=True, check=True) diff --git a/tests/cli/train_crew_test.py b/tests/cli/train_crew_test.py index 9d0d3d4a7..036dd7c2f 100644 --- a/tests/cli/train_crew_test.py +++ b/tests/cli/train_crew_test.py @@ -6,7 +6,6 @@ from crewai.cli.train_crew import train_crew @mock.patch("crewai.cli.train_crew.subprocess.run") def test_train_crew_positive_iterations(mock_subprocess_run): - # Arrange n_iterations = 5 mock_subprocess_run.return_value = subprocess.CompletedProcess( args=["poetry", "run", "train", str(n_iterations)], @@ -15,12 +14,10 @@ def test_train_crew_positive_iterations(mock_subprocess_run): stderr="", ) - # Act - train_crew(n_iterations) + train_crew(n_iterations, "trained_agents_data.pkl") - # Assert mock_subprocess_run.assert_called_once_with( - ["poetry", "run", "train", str(n_iterations)], + ["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"], capture_output=False, text=True, check=True, @@ -29,7 +26,7 @@ def test_train_crew_positive_iterations(mock_subprocess_run): @mock.patch("crewai.cli.train_crew.click") def test_train_crew_zero_iterations(click): - train_crew(0) + train_crew(0, "trained_agents_data.pkl") click.echo.assert_called_once_with( "An unexpected error occurred: The number of iterations must be a positive integer.", err=True, @@ -38,7 +35,7 @@ def test_train_crew_zero_iterations(click): @mock.patch("crewai.cli.train_crew.click") def test_train_crew_negative_iterations(click): - train_crew(-2) + train_crew(-2, "trained_agents_data.pkl") click.echo.assert_called_once_with( "An unexpected error occurred: The number of iterations must be a positive integer.", err=True, @@ -55,10 +52,13 @@ def test_train_crew_called_process_error(mock_subprocess_run, click): output="Error", stderr="Some error occurred", ) - train_crew(n_iterations) + train_crew(n_iterations, "trained_agents_data.pkl") mock_subprocess_run.assert_called_once_with( - ["poetry", "run", "train", "5"], capture_output=False, text=True, check=True + ["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"], + capture_output=False, + text=True, + check=True, ) click.echo.assert_has_calls( [ @@ -74,13 +74,15 @@ def test_train_crew_called_process_error(mock_subprocess_run, click): @mock.patch("crewai.cli.train_crew.click") @mock.patch("crewai.cli.train_crew.subprocess.run") def test_train_crew_unexpected_exception(mock_subprocess_run, click): - # Arrange n_iterations = 5 mock_subprocess_run.side_effect = Exception("Unexpected error") - train_crew(n_iterations) + train_crew(n_iterations, "trained_agents_data.pkl") mock_subprocess_run.assert_called_once_with( - ["poetry", "run", "train", "5"], capture_output=False, text=True, check=True + ["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"], + capture_output=False, + text=True, + check=True, ) click.echo.assert_called_once_with( "An unexpected error occurred: Unexpected error", err=True