mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 00:28:31 +00:00
feat: add ability to train on custom file
This commit is contained in:
@@ -50,10 +50,17 @@ def version(tools):
|
|||||||
default=5,
|
default=5,
|
||||||
help="Number of iterations to train the crew",
|
help="Number of iterations to train the crew",
|
||||||
)
|
)
|
||||||
def train(n_iterations: int):
|
@click.option(
|
||||||
|
"-f",
|
||||||
|
"--filename",
|
||||||
|
type=str,
|
||||||
|
default="trained_agents_data.pkl",
|
||||||
|
help="Path to a custom file for training",
|
||||||
|
)
|
||||||
|
def train(n_iterations: int, filename: str):
|
||||||
"""Train the crew."""
|
"""Train the crew."""
|
||||||
click.echo(f"Training the crew for {n_iterations} iterations")
|
click.echo(f"Training the Crew for {n_iterations} iterations")
|
||||||
train_crew(n_iterations)
|
train_crew(n_iterations, filename)
|
||||||
|
|
||||||
|
|
||||||
@crewai.command()
|
@crewai.command()
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ def train():
|
|||||||
"topic": "AI LLMs"
|
"topic": "AI LLMs"
|
||||||
}
|
}
|
||||||
try:
|
try:
|
||||||
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), inputs=inputs)
|
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"An error occurred while training the crew: {e}")
|
raise Exception(f"An error occurred while training the crew: {e}")
|
||||||
|
|||||||
@@ -3,14 +3,14 @@ import subprocess
|
|||||||
import click
|
import click
|
||||||
|
|
||||||
|
|
||||||
def train_crew(n_iterations: int) -> None:
|
def train_crew(n_iterations: int, filename: str) -> None:
|
||||||
"""
|
"""
|
||||||
Train the crew by running a command in the Poetry environment.
|
Train the crew by running a command in the Poetry environment.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
n_iterations (int): The number of iterations to train the crew.
|
n_iterations (int): The number of iterations to train the crew.
|
||||||
"""
|
"""
|
||||||
command = ["poetry", "run", "train", str(n_iterations)]
|
command = ["poetry", "run", "train", str(n_iterations), filename]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if n_iterations <= 0:
|
if n_iterations <= 0:
|
||||||
|
|||||||
@@ -34,7 +34,6 @@ from crewai.telemetry import Telemetry
|
|||||||
from crewai.tools.agent_tools import AgentTools
|
from crewai.tools.agent_tools import AgentTools
|
||||||
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
from crewai.utilities import I18N, FileHandler, Logger, RPMController
|
||||||
from crewai.utilities.constants import (
|
from crewai.utilities.constants import (
|
||||||
TRAINED_AGENTS_DATA_FILE,
|
|
||||||
TRAINING_DATA_FILE,
|
TRAINING_DATA_FILE,
|
||||||
)
|
)
|
||||||
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
|
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
|
||||||
@@ -386,7 +385,7 @@ class Crew(BaseModel):
|
|||||||
del task_config["agent"]
|
del task_config["agent"]
|
||||||
return Task(**task_config, agent=task_agent)
|
return Task(**task_config, agent=task_agent)
|
||||||
|
|
||||||
def _setup_for_training(self) -> None:
|
def _setup_for_training(self, filename: str) -> None:
|
||||||
"""Sets up the crew for training."""
|
"""Sets up the crew for training."""
|
||||||
self._train = True
|
self._train = True
|
||||||
|
|
||||||
@@ -397,11 +396,13 @@ class Crew(BaseModel):
|
|||||||
agent.allow_delegation = False
|
agent.allow_delegation = False
|
||||||
|
|
||||||
CrewTrainingHandler(TRAINING_DATA_FILE).initialize_file()
|
CrewTrainingHandler(TRAINING_DATA_FILE).initialize_file()
|
||||||
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).initialize_file()
|
CrewTrainingHandler(filename).initialize_file()
|
||||||
|
|
||||||
def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> None:
|
def train(
|
||||||
|
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = {}
|
||||||
|
) -> None:
|
||||||
"""Trains the crew for a given number of iterations."""
|
"""Trains the crew for a given number of iterations."""
|
||||||
self._setup_for_training()
|
self._setup_for_training(filename)
|
||||||
|
|
||||||
for n_iteration in range(n_iterations):
|
for n_iteration in range(n_iterations):
|
||||||
self._train_iteration = n_iteration
|
self._train_iteration = n_iteration
|
||||||
@@ -414,7 +415,7 @@ class Crew(BaseModel):
|
|||||||
training_data=training_data, agent_id=str(agent.id)
|
training_data=training_data, agent_id=str(agent.id)
|
||||||
)
|
)
|
||||||
|
|
||||||
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).save_trained_data(
|
CrewTrainingHandler(filename).save_trained_data(
|
||||||
agent_id=str(agent.role), trained_data=result.model_dump()
|
agent_id=str(agent.role), trained_data=result.model_dump()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,5 @@
|
|||||||
import os
|
import os
|
||||||
import pickle
|
import pickle
|
||||||
|
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
|
||||||
|
|
||||||
@@ -32,14 +30,16 @@ class PickleHandler:
|
|||||||
Parameters:
|
Parameters:
|
||||||
- file_name (str): The name of the file for saving and loading data.
|
- file_name (str): The name of the file for saving and loading data.
|
||||||
"""
|
"""
|
||||||
|
if not file_name.endswith(".pkl"):
|
||||||
|
file_name += ".pkl"
|
||||||
|
|
||||||
self.file_path = os.path.join(os.getcwd(), file_name)
|
self.file_path = os.path.join(os.getcwd(), file_name)
|
||||||
|
|
||||||
def initialize_file(self) -> None:
|
def initialize_file(self) -> None:
|
||||||
"""
|
"""
|
||||||
Initialize the file with an empty dictionary if it does not exist or is empty.
|
Initialize the file with an empty dictionary and overwrite any existing data.
|
||||||
"""
|
"""
|
||||||
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
|
self.save({})
|
||||||
self.save({}) # Save an empty dictionary to initialize the file
|
|
||||||
|
|
||||||
def save(self, data) -> None:
|
def save(self, data) -> None:
|
||||||
"""
|
"""
|
||||||
|
|||||||
Reference in New Issue
Block a user