feat: add ability to train on custom file (#1161)

* feat: add ability to train on custom file

* feat: add pkl file validation

* feat: fix tests

* feat: fix tests

* feat: fix tests
This commit is contained in:
Eduardo Chiarotti
2024-08-09 19:41:58 -03:00
committed by GitHub
parent 62f5b2fb2e
commit 51ee483e9d
8 changed files with 54 additions and 35 deletions

View File

@@ -60,10 +60,17 @@ def version(tools):
default=5,
help="Number of iterations to train the crew",
)
def train(n_iterations: int):
@click.option(
"-f",
"--filename",
type=str,
default="trained_agents_data.pkl",
help="Path to a custom file for training",
)
def train(n_iterations: int, filename: str):
"""Train the crew."""
click.echo(f"Training the crew for {n_iterations} iterations")
train_crew(n_iterations)
click.echo(f"Training the Crew for {n_iterations} iterations")
train_crew(n_iterations, filename)
@crewai.command()

View File

@@ -25,7 +25,7 @@ def train():
"topic": "AI LLMs"
}
try:
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), inputs=inputs)
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs)
except Exception as e:
raise Exception(f"An error occurred while training the crew: {e}")

View File

@@ -3,19 +3,22 @@ import subprocess
import click
def train_crew(n_iterations: int) -> None:
def train_crew(n_iterations: int, filename: str) -> None:
"""
Train the crew by running a command in the Poetry environment.
Args:
n_iterations (int): The number of iterations to train the crew.
"""
command = ["poetry", "run", "train", str(n_iterations)]
command = ["poetry", "run", "train", str(n_iterations), filename]
try:
if n_iterations <= 0:
raise ValueError("The number of iterations must be a positive integer.")
if not filename.endswith(".pkl"):
raise ValueError("The filename must not end with .pkl")
result = subprocess.run(command, capture_output=False, text=True, check=True)
if result.stderr: