mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-27 17:18:13 +00:00
feat: add ability to train on custom file (#1161)
* feat: add ability to train on custom file * feat: add pkl file validation * feat: fix tests * feat: fix tests * feat: fix tests
This commit is contained in:
committed by
GitHub
parent
62f5b2fb2e
commit
51ee483e9d
@@ -60,10 +60,17 @@ def version(tools):
|
||||
default=5,
|
||||
help="Number of iterations to train the crew",
|
||||
)
|
||||
def train(n_iterations: int):
|
||||
@click.option(
|
||||
"-f",
|
||||
"--filename",
|
||||
type=str,
|
||||
default="trained_agents_data.pkl",
|
||||
help="Path to a custom file for training",
|
||||
)
|
||||
def train(n_iterations: int, filename: str):
|
||||
"""Train the crew."""
|
||||
click.echo(f"Training the crew for {n_iterations} iterations")
|
||||
train_crew(n_iterations)
|
||||
click.echo(f"Training the Crew for {n_iterations} iterations")
|
||||
train_crew(n_iterations, filename)
|
||||
|
||||
|
||||
@crewai.command()
|
||||
|
||||
@@ -25,7 +25,7 @@ def train():
|
||||
"topic": "AI LLMs"
|
||||
}
|
||||
try:
|
||||
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), inputs=inputs)
|
||||
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while training the crew: {e}")
|
||||
|
||||
@@ -3,19 +3,22 @@ import subprocess
|
||||
import click
|
||||
|
||||
|
||||
def train_crew(n_iterations: int) -> None:
|
||||
def train_crew(n_iterations: int, filename: str) -> None:
|
||||
"""
|
||||
Train the crew by running a command in the Poetry environment.
|
||||
|
||||
Args:
|
||||
n_iterations (int): The number of iterations to train the crew.
|
||||
"""
|
||||
command = ["poetry", "run", "train", str(n_iterations)]
|
||||
command = ["poetry", "run", "train", str(n_iterations), filename]
|
||||
|
||||
try:
|
||||
if n_iterations <= 0:
|
||||
raise ValueError("The number of iterations must be a positive integer.")
|
||||
|
||||
if not filename.endswith(".pkl"):
|
||||
raise ValueError("The filename must not end with .pkl")
|
||||
|
||||
result = subprocess.run(command, capture_output=False, text=True, check=True)
|
||||
|
||||
if result.stderr:
|
||||
|
||||
Reference in New Issue
Block a user