diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index 9d748715a..bc6d5c21d 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -60,10 +60,17 @@ def version(tools): default=5, help="Number of iterations to train the crew", ) -def train(n_iterations: int): +@click.option( + "-f", + "--filename", + type=str, + default="trained_agents_data.pkl", + help="Path to a custom file for training", +) +def train(n_iterations: int, filename: str): """Train the crew.""" - click.echo(f"Training the crew for {n_iterations} iterations") - train_crew(n_iterations) + click.echo(f"Training the Crew for {n_iterations} iterations") + train_crew(n_iterations, filename) @crewai.command() diff --git a/src/crewai/cli/templates/crew/main.py b/src/crewai/cli/templates/crew/main.py index 94ac35b93..86beeeaca 100644 --- a/src/crewai/cli/templates/crew/main.py +++ b/src/crewai/cli/templates/crew/main.py @@ -25,7 +25,7 @@ def train(): "topic": "AI LLMs" } try: - {{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), inputs=inputs) + {{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs) except Exception as e: raise Exception(f"An error occurred while training the crew: {e}") diff --git a/src/crewai/cli/train_crew.py b/src/crewai/cli/train_crew.py index cd880db5d..12c5191b1 100644 --- a/src/crewai/cli/train_crew.py +++ b/src/crewai/cli/train_crew.py @@ -3,19 +3,22 @@ import subprocess import click -def train_crew(n_iterations: int) -> None: +def train_crew(n_iterations: int, filename: str) -> None: """ Train the crew by running a command in the Poetry environment. Args: n_iterations (int): The number of iterations to train the crew. 
""" - command = ["poetry", "run", "train", str(n_iterations)] + command = ["poetry", "run", "train", str(n_iterations), filename] try: if n_iterations <= 0: raise ValueError("The number of iterations must be a positive integer.") + if not filename.endswith(".pkl"): + raise ValueError("The filename must end with .pkl") + result = subprocess.run(command, capture_output=False, text=True, check=True) if result.stderr: diff --git a/src/crewai/crew.py b/src/crewai/crew.py index d7998ecff..f3f032294 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -34,7 +34,9 @@ from crewai.telemetry import Telemetry from crewai.tools.agent_tools import AgentTools from crewai.types.usage_metrics import UsageMetrics from crewai.utilities import I18N, FileHandler, Logger, RPMController -from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE +from crewai.utilities.constants import ( + TRAINING_DATA_FILE, +) from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator from crewai.utilities.evaluators.task_evaluator import TaskEvaluator from crewai.utilities.formatter import ( @@ -388,7 +390,7 @@ class Crew(BaseModel): del task_config["agent"] return Task(**task_config, agent=task_agent) - def _setup_for_training(self) -> None: + def _setup_for_training(self, filename: str) -> None: """Sets up the crew for training.""" self._train = True @@ -399,11 +401,13 @@ class Crew(BaseModel): agent.allow_delegation = False CrewTrainingHandler(TRAINING_DATA_FILE).initialize_file() - CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).initialize_file() + CrewTrainingHandler(filename).initialize_file() - def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> None: + def train( + self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = {} + ) -> None: """Trains the crew for a given number of iterations.""" - self._setup_for_training() + self._setup_for_training(filename) for n_iteration in range(n_iterations): 
self._train_iteration = n_iteration @@ -416,7 +420,7 @@ class Crew(BaseModel): training_data=training_data, agent_id=str(agent.id) ) - CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).save_trained_data( + CrewTrainingHandler(filename).save_trained_data( agent_id=str(agent.role), trained_data=result.model_dump() ) diff --git a/src/crewai/utilities/file_handler.py b/src/crewai/utilities/file_handler.py index 68c33241d..1125cae4e 100644 --- a/src/crewai/utilities/file_handler.py +++ b/src/crewai/utilities/file_handler.py @@ -1,7 +1,5 @@ import os import pickle - - from datetime import datetime @@ -32,14 +30,16 @@ class PickleHandler: Parameters: - file_name (str): The name of the file for saving and loading data. """ + if not file_name.endswith(".pkl"): + file_name += ".pkl" + self.file_path = os.path.join(os.getcwd(), file_name) def initialize_file(self) -> None: """ - Initialize the file with an empty dictionary if it does not exist or is empty. + Initialize the file with an empty dictionary and overwrite any existing data. 
""" - if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0: - self.save({}) # Save an empty dictionary to initialize the file + self.save({}) def save(self, data) -> None: """ diff --git a/tests/cli/cli_test.py b/tests/cli/cli_test.py index 509b9193a..4f606e213 100644 --- a/tests/cli/cli_test.py +++ b/tests/cli/cli_test.py @@ -15,18 +15,18 @@ def runner(): def test_train_default_iterations(train_crew, runner): result = runner.invoke(train) - train_crew.assert_called_once_with(5) + train_crew.assert_called_once_with(5, "trained_agents_data.pkl") assert result.exit_code == 0 - assert "Training the crew for 5 iterations" in result.output + assert "Training the Crew for 5 iterations" in result.output @mock.patch("crewai.cli.cli.train_crew") def test_train_custom_iterations(train_crew, runner): result = runner.invoke(train, ["--n_iterations", "10"]) - train_crew.assert_called_once_with(10) + train_crew.assert_called_once_with(10, "trained_agents_data.pkl") assert result.exit_code == 0 - assert "Training the crew for 10 iterations" in result.output + assert "Training the Crew for 10 iterations" in result.output @mock.patch("crewai.cli.cli.train_crew") diff --git a/tests/cli/train_crew_test.py b/tests/cli/train_crew_test.py index 9d0d3d4a7..036dd7c2f 100644 --- a/tests/cli/train_crew_test.py +++ b/tests/cli/train_crew_test.py @@ -6,7 +6,6 @@ from crewai.cli.train_crew import train_crew @mock.patch("crewai.cli.train_crew.subprocess.run") def test_train_crew_positive_iterations(mock_subprocess_run): - # Arrange n_iterations = 5 mock_subprocess_run.return_value = subprocess.CompletedProcess( args=["poetry", "run", "train", str(n_iterations)], @@ -15,12 +14,10 @@ def test_train_crew_positive_iterations(mock_subprocess_run): stderr="", ) - # Act - train_crew(n_iterations) + train_crew(n_iterations, "trained_agents_data.pkl") - # Assert mock_subprocess_run.assert_called_once_with( - ["poetry", "run", "train", str(n_iterations)], + ["poetry", "run", 
"train", str(n_iterations), "trained_agents_data.pkl"], capture_output=False, text=True, check=True, @@ -29,7 +26,7 @@ def test_train_crew_positive_iterations(mock_subprocess_run): @mock.patch("crewai.cli.train_crew.click") def test_train_crew_zero_iterations(click): - train_crew(0) + train_crew(0, "trained_agents_data.pkl") click.echo.assert_called_once_with( "An unexpected error occurred: The number of iterations must be a positive integer.", err=True, @@ -38,7 +35,7 @@ def test_train_crew_zero_iterations(click): @mock.patch("crewai.cli.train_crew.click") def test_train_crew_negative_iterations(click): - train_crew(-2) + train_crew(-2, "trained_agents_data.pkl") click.echo.assert_called_once_with( "An unexpected error occurred: The number of iterations must be a positive integer.", err=True, @@ -55,10 +52,13 @@ def test_train_crew_called_process_error(mock_subprocess_run, click): output="Error", stderr="Some error occurred", ) - train_crew(n_iterations) + train_crew(n_iterations, "trained_agents_data.pkl") mock_subprocess_run.assert_called_once_with( - ["poetry", "run", "train", "5"], capture_output=False, text=True, check=True + ["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"], + capture_output=False, + text=True, + check=True, ) click.echo.assert_has_calls( [ @@ -74,13 +74,15 @@ def test_train_crew_called_process_error(mock_subprocess_run, click): @mock.patch("crewai.cli.train_crew.click") @mock.patch("crewai.cli.train_crew.subprocess.run") def test_train_crew_unexpected_exception(mock_subprocess_run, click): - # Arrange n_iterations = 5 mock_subprocess_run.side_effect = Exception("Unexpected error") - train_crew(n_iterations) + train_crew(n_iterations, "trained_agents_data.pkl") mock_subprocess_run.assert_called_once_with( - ["poetry", "run", "train", "5"], capture_output=False, text=True, check=True + ["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"], + capture_output=False, + text=True, + check=True, ) 
click.echo.assert_called_once_with( "An unexpected error occurred: Unexpected error", err=True diff --git a/tests/crew_test.py b/tests/crew_test.py index 0c49d51f2..be7ec1da3 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch import pydantic_core import pytest + from crewai.agent import Agent from crewai.agents.cache import CacheHandler from crewai.crew import Crew @@ -1806,7 +1807,9 @@ def test_crew_train_success(task_evaluator, crew_training_handler, kickoff): agents=[researcher, writer], tasks=[task], ) - crew.train(n_iterations=2, inputs={"topic": "AI"}) + crew.train( + n_iterations=2, inputs={"topic": "AI"}, filename="trained_agents_data.pkl" + ) task_evaluator.assert_has_calls( [ mock.call(researcher), @@ -1890,7 +1893,7 @@ def test__setup_for_training(): for agent in agents: assert agent.allow_delegation is True - crew._setup_for_training() + crew._setup_for_training("trained_agents_data.pkl") assert crew._train is True assert task.human_input is True