From aa8640c0860329577f2c3e72682866c71836d505 Mon Sep 17 00:00:00 2001 From: Eduardo Chiarotti Date: Thu, 8 Aug 2024 19:30:45 -0300 Subject: [PATCH] feat: add ability to train on custom file --- src/crewai/cli/cli.py | 13 ++++++++++--- src/crewai/cli/templates/main.py | 2 +- src/crewai/cli/train_crew.py | 4 ++-- src/crewai/crew.py | 13 +++++++------ src/crewai/utilities/file_handler.py | 10 +++++----- 5 files changed, 25 insertions(+), 17 deletions(-) diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index 5ae9feb03..5d9ec2843 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -50,10 +50,17 @@ def version(tools): default=5, help="Number of iterations to train the crew", ) -def train(n_iterations: int): +@click.option( + "-f", + "--filename", + type=str, + default="trained_agents_data.pkl", + help="Path to a custom file for training", +) +def train(n_iterations: int, filename: str): """Train the crew.""" - click.echo(f"Training the crew for {n_iterations} iterations") - train_crew(n_iterations) + click.echo(f"Training the Crew for {n_iterations} iterations") + train_crew(n_iterations, filename) @crewai.command() diff --git a/src/crewai/cli/templates/main.py b/src/crewai/cli/templates/main.py index 94ac35b93..86beeeaca 100644 --- a/src/crewai/cli/templates/main.py +++ b/src/crewai/cli/templates/main.py @@ -25,7 +25,7 @@ def train(): "topic": "AI LLMs" } try: - {{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), inputs=inputs) + {{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs) except Exception as e: raise Exception(f"An error occurred while training the crew: {e}") diff --git a/src/crewai/cli/train_crew.py b/src/crewai/cli/train_crew.py index cd880db5d..369cd558d 100644 --- a/src/crewai/cli/train_crew.py +++ b/src/crewai/cli/train_crew.py @@ -3,14 +3,14 @@ import subprocess import click -def train_crew(n_iterations: int) -> None: +def train_crew(n_iterations: int, filename: str) -> None: """ Train the crew by running a command in the Poetry environment. Args: n_iterations (int): The number of iterations to train the crew. """ - command = ["poetry", "run", "train", str(n_iterations)] + command = ["poetry", "run", "train", str(n_iterations), filename] try: if n_iterations <= 0: diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 2c84e3c4b..67c9f6772 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -34,7 +34,6 @@ from crewai.telemetry import Telemetry from crewai.tools.agent_tools import AgentTools from crewai.utilities import I18N, FileHandler, Logger, RPMController from crewai.utilities.constants import ( - TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE, ) from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator @@ -386,7 +385,7 @@ class Crew(BaseModel): del task_config["agent"] return Task(**task_config, agent=task_agent) - def _setup_for_training(self) -> None: + def _setup_for_training(self, filename: str) -> None: """Sets up the crew for training.""" self._train = True @@ -397,11 +396,13 @@ class Crew(BaseModel): agent.allow_delegation = False CrewTrainingHandler(TRAINING_DATA_FILE).initialize_file() - CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).initialize_file() + CrewTrainingHandler(filename).initialize_file() - def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> None: + def train( + self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = {} + ) -> None: """Trains the crew for a given number of iterations.""" - self._setup_for_training() + self._setup_for_training(filename) for n_iteration in range(n_iterations): self._train_iteration = n_iteration @@ -414,7 +415,7 @@ class Crew(BaseModel): training_data=training_data, agent_id=str(agent.id) ) - CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).save_trained_data( + CrewTrainingHandler(filename).save_trained_data( agent_id=str(agent.role), trained_data=result.model_dump() ) diff --git a/src/crewai/utilities/file_handler.py b/src/crewai/utilities/file_handler.py index 68c33241d..1125cae4e 100644 --- a/src/crewai/utilities/file_handler.py +++ b/src/crewai/utilities/file_handler.py @@ -1,7 +1,5 @@ import os import pickle - - from datetime import datetime @@ -32,14 +30,16 @@ class PickleHandler: Parameters: - file_name (str): The name of the file for saving and loading data. """ + if not file_name.endswith(".pkl"): + file_name += ".pkl" + self.file_path = os.path.join(os.getcwd(), file_name) def initialize_file(self) -> None: """ - Initialize the file with an empty dictionary if it does not exist or is empty. + Initialize the file with an empty dictionary and overwrite any existing data. """ - if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0: - self.save({}) # Save an empty dictionary to initialize the file + self.save({}) def save(self, data) -> None: """