feat: add ability to train on custom file (#1161)

* feat: add ability to train on custom file

* feat: add pkl file validation

* feat: fix tests

* feat: fix tests

* feat: fix tests
This commit is contained in:
Eduardo Chiarotti
2024-08-09 19:41:58 -03:00
committed by GitHub
parent 62f5b2fb2e
commit 51ee483e9d
8 changed files with 54 additions and 35 deletions

View File

@@ -60,10 +60,17 @@ def version(tools):
default=5,
help="Number of iterations to train the crew",
)
def train(n_iterations: int):
@click.option(
"-f",
"--filename",
type=str,
default="trained_agents_data.pkl",
help="Path to a custom file for training",
)
def train(n_iterations: int, filename: str):
"""Train the crew."""
click.echo(f"Training the crew for {n_iterations} iterations")
train_crew(n_iterations)
click.echo(f"Training the Crew for {n_iterations} iterations")
train_crew(n_iterations, filename)
@crewai.command()

View File

@@ -25,7 +25,7 @@ def train():
"topic": "AI LLMs"
}
try:
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), inputs=inputs)
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs)
except Exception as e:
raise Exception(f"An error occurred while training the crew: {e}")

View File

@@ -3,19 +3,22 @@ import subprocess
import click
def train_crew(n_iterations: int) -> None:
def train_crew(n_iterations: int, filename: str) -> None:
"""
Train the crew by running a command in the Poetry environment.
Args:
n_iterations (int): The number of iterations to train the crew.
"""
command = ["poetry", "run", "train", str(n_iterations)]
command = ["poetry", "run", "train", str(n_iterations), filename]
try:
if n_iterations <= 0:
raise ValueError("The number of iterations must be a positive integer.")
if not filename.endswith(".pkl"):
raise ValueError("The filename must not end with .pkl")
result = subprocess.run(command, capture_output=False, text=True, check=True)
if result.stderr:

View File

@@ -34,7 +34,9 @@ from crewai.telemetry import Telemetry
from crewai.tools.agent_tools import AgentTools
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.constants import (
TRAINING_DATA_FILE,
)
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.formatter import (
@@ -388,7 +390,7 @@ class Crew(BaseModel):
del task_config["agent"]
return Task(**task_config, agent=task_agent)
def _setup_for_training(self) -> None:
def _setup_for_training(self, filename: str) -> None:
"""Sets up the crew for training."""
self._train = True
@@ -399,11 +401,13 @@ class Crew(BaseModel):
agent.allow_delegation = False
CrewTrainingHandler(TRAINING_DATA_FILE).initialize_file()
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).initialize_file()
CrewTrainingHandler(filename).initialize_file()
def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> None:
def train(
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = {}
) -> None:
"""Trains the crew for a given number of iterations."""
self._setup_for_training()
self._setup_for_training(filename)
for n_iteration in range(n_iterations):
self._train_iteration = n_iteration
@@ -416,7 +420,7 @@ class Crew(BaseModel):
training_data=training_data, agent_id=str(agent.id)
)
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).save_trained_data(
CrewTrainingHandler(filename).save_trained_data(
agent_id=str(agent.role), trained_data=result.model_dump()
)

View File

@@ -1,7 +1,5 @@
import os
import pickle
from datetime import datetime
@@ -32,14 +30,16 @@ class PickleHandler:
Parameters:
- file_name (str): The name of the file for saving and loading data.
"""
if not file_name.endswith(".pkl"):
file_name += ".pkl"
self.file_path = os.path.join(os.getcwd(), file_name)
def initialize_file(self) -> None:
"""
Initialize the file with an empty dictionary if it does not exist or is empty.
Initialize the file with an empty dictionary and overwrite any existing data.
"""
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
self.save({}) # Save an empty dictionary to initialize the file
self.save({})
def save(self, data) -> None:
"""