feat: add ability to train on custom file

This commit is contained in:
Eduardo Chiarotti
2024-08-08 19:30:45 -03:00
parent 217f5fc5ac
commit aa8640c086
5 changed files with 25 additions and 17 deletions

View File

@@ -50,10 +50,17 @@ def version(tools):
default=5,
help="Number of iterations to train the crew",
)
def train(n_iterations: int):
@click.option(
"-f",
"--filename",
type=str,
default="trained_agents_data.pkl",
help="Path to a custom file for training",
)
def train(n_iterations: int, filename: str):
"""Train the crew."""
click.echo(f"Training the crew for {n_iterations} iterations")
train_crew(n_iterations)
click.echo(f"Training the Crew for {n_iterations} iterations")
train_crew(n_iterations, filename)
@crewai.command()

View File

@@ -25,7 +25,7 @@ def train():
"topic": "AI LLMs"
}
try:
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), inputs=inputs)
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs)
except Exception as e:
raise Exception(f"An error occurred while training the crew: {e}")

View File

@@ -3,14 +3,14 @@ import subprocess
import click
def train_crew(n_iterations: int) -> None:
def train_crew(n_iterations: int, filename: str) -> None:
"""
Train the crew by running a command in the Poetry environment.
Args:
n_iterations (int): The number of iterations to train the crew.
"""
command = ["poetry", "run", "train", str(n_iterations)]
command = ["poetry", "run", "train", str(n_iterations), filename]
try:
if n_iterations <= 0:

View File

@@ -34,7 +34,6 @@ from crewai.telemetry import Telemetry
from crewai.tools.agent_tools import AgentTools
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import (
TRAINED_AGENTS_DATA_FILE,
TRAINING_DATA_FILE,
)
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
@@ -386,7 +385,7 @@ class Crew(BaseModel):
del task_config["agent"]
return Task(**task_config, agent=task_agent)
def _setup_for_training(self) -> None:
def _setup_for_training(self, filename: str) -> None:
"""Sets up the crew for training."""
self._train = True
@@ -397,11 +396,13 @@ class Crew(BaseModel):
agent.allow_delegation = False
CrewTrainingHandler(TRAINING_DATA_FILE).initialize_file()
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).initialize_file()
CrewTrainingHandler(filename).initialize_file()
def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> None:
def train(
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = {}
) -> None:
"""Trains the crew for a given number of iterations."""
self._setup_for_training()
self._setup_for_training(filename)
for n_iteration in range(n_iterations):
self._train_iteration = n_iteration
@@ -414,7 +415,7 @@ class Crew(BaseModel):
training_data=training_data, agent_id=str(agent.id)
)
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).save_trained_data(
CrewTrainingHandler(filename).save_trained_data(
agent_id=str(agent.role), trained_data=result.model_dump()
)

View File

@@ -1,7 +1,5 @@
import os
import pickle
from datetime import datetime
@@ -32,14 +30,16 @@ class PickleHandler:
Parameters:
- file_name (str): The name of the file for saving and loading data.
"""
if not file_name.endswith(".pkl"):
file_name += ".pkl"
self.file_path = os.path.join(os.getcwd(), file_name)
def initialize_file(self) -> None:
"""
Initialize the file with an empty dictionary if it does not exist or is empty.
Initialize the file with an empty dictionary and overwrite any existing data.
"""
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
self.save({}) # Save an empty dictionary to initialize the file
self.save({})
def save(self, data) -> None:
"""