mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
feat: add ability to train on custom file
This commit is contained in:
@@ -50,10 +50,17 @@ def version(tools):
|
||||
default=5,
|
||||
help="Number of iterations to train the crew",
|
||||
)
|
||||
def train(n_iterations: int):
|
||||
@click.option(
|
||||
"-f",
|
||||
"--filename",
|
||||
type=str,
|
||||
default="trained_agents_data.pkl",
|
||||
help="Path to a custom file for training",
|
||||
)
|
||||
def train(n_iterations: int, filename: str):
|
||||
"""Train the crew."""
|
||||
click.echo(f"Training the crew for {n_iterations} iterations")
|
||||
train_crew(n_iterations)
|
||||
click.echo(f"Training the Crew for {n_iterations} iterations")
|
||||
train_crew(n_iterations, filename)
|
||||
|
||||
|
||||
@crewai.command()
|
||||
|
||||
@@ -25,7 +25,7 @@ def train():
|
||||
"topic": "AI LLMs"
|
||||
}
|
||||
try:
|
||||
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), inputs=inputs)
|
||||
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs)
|
||||
|
||||
except Exception as e:
|
||||
raise Exception(f"An error occurred while training the crew: {e}")
|
||||
|
||||
@@ -3,14 +3,14 @@ import subprocess
|
||||
import click
|
||||
|
||||
|
||||
def train_crew(n_iterations: int) -> None:
|
||||
def train_crew(n_iterations: int, filename: str) -> None:
|
||||
"""
|
||||
Train the crew by running a command in the Poetry environment.
|
||||
|
||||
Args:
|
||||
n_iterations (int): The number of iterations to train the crew.
|
||||
"""
|
||||
command = ["poetry", "run", "train", str(n_iterations)]
|
||||
command = ["poetry", "run", "train", str(n_iterations), filename]
|
||||
|
||||
try:
|
||||
if n_iterations <= 0:
|
||||
|
||||
@@ -34,7 +34,6 @@ from crewai.telemetry import Telemetry
|
||||
from crewai.tools.agent_tools import AgentTools
|
||||
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
||||
from crewai.utilities.constants import (
|
||||
TRAINED_AGENTS_DATA_FILE,
|
||||
TRAINING_DATA_FILE,
|
||||
)
|
||||
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
|
||||
@@ -386,7 +385,7 @@ class Crew(BaseModel):
|
||||
del task_config["agent"]
|
||||
return Task(**task_config, agent=task_agent)
|
||||
|
||||
def _setup_for_training(self) -> None:
|
||||
def _setup_for_training(self, filename: str) -> None:
|
||||
"""Sets up the crew for training."""
|
||||
self._train = True
|
||||
|
||||
@@ -397,11 +396,13 @@ class Crew(BaseModel):
|
||||
agent.allow_delegation = False
|
||||
|
||||
CrewTrainingHandler(TRAINING_DATA_FILE).initialize_file()
|
||||
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).initialize_file()
|
||||
CrewTrainingHandler(filename).initialize_file()
|
||||
|
||||
def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> None:
|
||||
def train(
|
||||
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = {}
|
||||
) -> None:
|
||||
"""Trains the crew for a given number of iterations."""
|
||||
self._setup_for_training()
|
||||
self._setup_for_training(filename)
|
||||
|
||||
for n_iteration in range(n_iterations):
|
||||
self._train_iteration = n_iteration
|
||||
@@ -414,7 +415,7 @@ class Crew(BaseModel):
|
||||
training_data=training_data, agent_id=str(agent.id)
|
||||
)
|
||||
|
||||
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).save_trained_data(
|
||||
CrewTrainingHandler(filename).save_trained_data(
|
||||
agent_id=str(agent.role), trained_data=result.model_dump()
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import os
|
||||
import pickle
|
||||
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@@ -32,14 +30,16 @@ class PickleHandler:
|
||||
Parameters:
|
||||
- file_name (str): The name of the file for saving and loading data.
|
||||
"""
|
||||
if not file_name.endswith(".pkl"):
|
||||
file_name += ".pkl"
|
||||
|
||||
self.file_path = os.path.join(os.getcwd(), file_name)
|
||||
|
||||
def initialize_file(self) -> None:
|
||||
"""
|
||||
Initialize the file with an empty dictionary if it does not exist or is empty.
|
||||
Initialize the file with an empty dictionary and overwrite any existing data.
|
||||
"""
|
||||
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
|
||||
self.save({}) # Save an empty dictionary to initialize the file
|
||||
self.save({})
|
||||
|
||||
def save(self, data) -> None:
|
||||
"""
|
||||
|
||||
Reference in New Issue
Block a user