Compare commits

...

7 Commits

Author SHA1 Message Date
Eduardo Chiarotti
4b8df71ca1 Merge branch 'main' into feat/add-custom-training-file 2024-08-09 19:28:09 -03:00
Eduardo Chiarotti
cc187a23d7 Merge branch 'main' into feat/add-custom-training-file 2024-08-09 11:01:43 -03:00
Eduardo Chiarotti
540e328f06 feat: fix tests 2024-08-08 21:06:10 -03:00
Eduardo Chiarotti
87d4c5f092 feat: fix tests 2024-08-08 21:01:02 -03:00
Eduardo Chiarotti
45c16cfa6b feat: fix tests 2024-08-08 20:56:25 -03:00
Eduardo Chiarotti
9232ac3e3f feat: add pkl file validation 2024-08-08 19:46:20 -03:00
Eduardo Chiarotti
aa8640c086 feat: add ability to train on custom file 2024-08-08 19:30:45 -03:00
8 changed files with 54 additions and 35 deletions

View File

@@ -60,10 +60,17 @@ def version(tools):
default=5,
help="Number of iterations to train the crew",
)
def train(n_iterations: int):
@click.option(
"-f",
"--filename",
type=str,
default="trained_agents_data.pkl",
help="Path to a custom file for training",
)
def train(n_iterations: int, filename: str):
"""Train the crew."""
click.echo(f"Training the crew for {n_iterations} iterations")
train_crew(n_iterations)
click.echo(f"Training the Crew for {n_iterations} iterations")
train_crew(n_iterations, filename)
@crewai.command()

View File

@@ -25,7 +25,7 @@ def train():
"topic": "AI LLMs"
}
try:
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), inputs=inputs)
{{crew_name}}Crew().crew().train(n_iterations=int(sys.argv[1]), filename=sys.argv[2], inputs=inputs)
except Exception as e:
raise Exception(f"An error occurred while training the crew: {e}")

View File

@@ -3,19 +3,22 @@ import subprocess
import click
def train_crew(n_iterations: int) -> None:
def train_crew(n_iterations: int, filename: str) -> None:
"""
Train the crew by running a command in the Poetry environment.
Args:
n_iterations (int): The number of iterations to train the crew.
"""
command = ["poetry", "run", "train", str(n_iterations)]
command = ["poetry", "run", "train", str(n_iterations), filename]
try:
if n_iterations <= 0:
raise ValueError("The number of iterations must be a positive integer.")
if not filename.endswith(".pkl"):
raise ValueError("The filename must not end with .pkl")
result = subprocess.run(command, capture_output=False, text=True, check=True)
if result.stderr:

View File

@@ -34,7 +34,9 @@ from crewai.telemetry import Telemetry
from crewai.tools.agent_tools import AgentTools
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.constants import (
TRAINING_DATA_FILE,
)
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.formatter import (
@@ -388,7 +390,7 @@ class Crew(BaseModel):
del task_config["agent"]
return Task(**task_config, agent=task_agent)
def _setup_for_training(self) -> None:
def _setup_for_training(self, filename: str) -> None:
"""Sets up the crew for training."""
self._train = True
@@ -399,11 +401,13 @@ class Crew(BaseModel):
agent.allow_delegation = False
CrewTrainingHandler(TRAINING_DATA_FILE).initialize_file()
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).initialize_file()
CrewTrainingHandler(filename).initialize_file()
def train(self, n_iterations: int, inputs: Optional[Dict[str, Any]] = {}) -> None:
def train(
self, n_iterations: int, filename: str, inputs: Optional[Dict[str, Any]] = {}
) -> None:
"""Trains the crew for a given number of iterations."""
self._setup_for_training()
self._setup_for_training(filename)
for n_iteration in range(n_iterations):
self._train_iteration = n_iteration
@@ -416,7 +420,7 @@ class Crew(BaseModel):
training_data=training_data, agent_id=str(agent.id)
)
CrewTrainingHandler(TRAINED_AGENTS_DATA_FILE).save_trained_data(
CrewTrainingHandler(filename).save_trained_data(
agent_id=str(agent.role), trained_data=result.model_dump()
)

View File

@@ -1,7 +1,5 @@
import os
import pickle
from datetime import datetime
@@ -32,14 +30,16 @@ class PickleHandler:
Parameters:
- file_name (str): The name of the file for saving and loading data.
"""
if not file_name.endswith(".pkl"):
file_name += ".pkl"
self.file_path = os.path.join(os.getcwd(), file_name)
def initialize_file(self) -> None:
"""
Initialize the file with an empty dictionary if it does not exist or is empty.
Initialize the file with an empty dictionary and overwrite any existing data.
"""
if not os.path.exists(self.file_path) or os.path.getsize(self.file_path) == 0:
self.save({}) # Save an empty dictionary to initialize the file
self.save({})
def save(self, data) -> None:
"""

View File

@@ -15,18 +15,18 @@ def runner():
def test_train_default_iterations(train_crew, runner):
result = runner.invoke(train)
train_crew.assert_called_once_with(5)
train_crew.assert_called_once_with(5, "trained_agents_data.pkl")
assert result.exit_code == 0
assert "Training the crew for 5 iterations" in result.output
assert "Training the Crew for 5 iterations" in result.output
@mock.patch("crewai.cli.cli.train_crew")
def test_train_custom_iterations(train_crew, runner):
result = runner.invoke(train, ["--n_iterations", "10"])
train_crew.assert_called_once_with(10)
train_crew.assert_called_once_with(10, "trained_agents_data.pkl")
assert result.exit_code == 0
assert "Training the crew for 10 iterations" in result.output
assert "Training the Crew for 10 iterations" in result.output
@mock.patch("crewai.cli.cli.train_crew")

View File

@@ -6,7 +6,6 @@ from crewai.cli.train_crew import train_crew
@mock.patch("crewai.cli.train_crew.subprocess.run")
def test_train_crew_positive_iterations(mock_subprocess_run):
# Arrange
n_iterations = 5
mock_subprocess_run.return_value = subprocess.CompletedProcess(
args=["poetry", "run", "train", str(n_iterations)],
@@ -15,12 +14,10 @@ def test_train_crew_positive_iterations(mock_subprocess_run):
stderr="",
)
# Act
train_crew(n_iterations)
train_crew(n_iterations, "trained_agents_data.pkl")
# Assert
mock_subprocess_run.assert_called_once_with(
["poetry", "run", "train", str(n_iterations)],
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
capture_output=False,
text=True,
check=True,
@@ -29,7 +26,7 @@ def test_train_crew_positive_iterations(mock_subprocess_run):
@mock.patch("crewai.cli.train_crew.click")
def test_train_crew_zero_iterations(click):
train_crew(0)
train_crew(0, "trained_agents_data.pkl")
click.echo.assert_called_once_with(
"An unexpected error occurred: The number of iterations must be a positive integer.",
err=True,
@@ -38,7 +35,7 @@ def test_train_crew_zero_iterations(click):
@mock.patch("crewai.cli.train_crew.click")
def test_train_crew_negative_iterations(click):
train_crew(-2)
train_crew(-2, "trained_agents_data.pkl")
click.echo.assert_called_once_with(
"An unexpected error occurred: The number of iterations must be a positive integer.",
err=True,
@@ -55,10 +52,13 @@ def test_train_crew_called_process_error(mock_subprocess_run, click):
output="Error",
stderr="Some error occurred",
)
train_crew(n_iterations)
train_crew(n_iterations, "trained_agents_data.pkl")
mock_subprocess_run.assert_called_once_with(
["poetry", "run", "train", "5"], capture_output=False, text=True, check=True
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
capture_output=False,
text=True,
check=True,
)
click.echo.assert_has_calls(
[
@@ -74,13 +74,15 @@ def test_train_crew_called_process_error(mock_subprocess_run, click):
@mock.patch("crewai.cli.train_crew.click")
@mock.patch("crewai.cli.train_crew.subprocess.run")
def test_train_crew_unexpected_exception(mock_subprocess_run, click):
# Arrange
n_iterations = 5
mock_subprocess_run.side_effect = Exception("Unexpected error")
train_crew(n_iterations)
train_crew(n_iterations, "trained_agents_data.pkl")
mock_subprocess_run.assert_called_once_with(
["poetry", "run", "train", "5"], capture_output=False, text=True, check=True
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
capture_output=False,
text=True,
check=True,
)
click.echo.assert_called_once_with(
"An unexpected error occurred: Unexpected error", err=True

View File

@@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
import pydantic_core
import pytest
from crewai.agent import Agent
from crewai.agents.cache import CacheHandler
from crewai.crew import Crew
@@ -1806,7 +1807,9 @@ def test_crew_train_success(task_evaluator, crew_training_handler, kickoff):
agents=[researcher, writer],
tasks=[task],
)
crew.train(n_iterations=2, inputs={"topic": "AI"})
crew.train(
n_iterations=2, inputs={"topic": "AI"}, filename="trained_agents_data.pkl"
)
task_evaluator.assert_has_calls(
[
mock.call(researcher),
@@ -1890,7 +1893,7 @@ def test__setup_for_training():
for agent in agents:
assert agent.allow_delegation is True
crew._setup_for_training()
crew._setup_for_training("trained_agents_data.pkl")
assert crew._train is True
assert task.human_input is True