feat: add ability to train on custom file (#1161)

* feat: add ability to train on custom file

* feat: add pkl file validation

* feat: fix tests

* feat: fix tests

* feat: fix tests
This commit is contained in:
Eduardo Chiarotti
2024-08-09 19:41:58 -03:00
committed by GitHub
parent 62f5b2fb2e
commit 51ee483e9d
8 changed files with 54 additions and 35 deletions

View File

@@ -15,18 +15,18 @@ def runner():
def test_train_default_iterations(train_crew, runner):
result = runner.invoke(train)
train_crew.assert_called_once_with(5)
train_crew.assert_called_once_with(5, "trained_agents_data.pkl")
assert result.exit_code == 0
assert "Training the crew for 5 iterations" in result.output
assert "Training the Crew for 5 iterations" in result.output
@mock.patch("crewai.cli.cli.train_crew")
def test_train_custom_iterations(train_crew, runner):
result = runner.invoke(train, ["--n_iterations", "10"])
train_crew.assert_called_once_with(10)
train_crew.assert_called_once_with(10, "trained_agents_data.pkl")
assert result.exit_code == 0
assert "Training the crew for 10 iterations" in result.output
assert "Training the Crew for 10 iterations" in result.output
@mock.patch("crewai.cli.cli.train_crew")

View File

@@ -6,7 +6,6 @@ from crewai.cli.train_crew import train_crew
@mock.patch("crewai.cli.train_crew.subprocess.run")
def test_train_crew_positive_iterations(mock_subprocess_run):
# Arrange
n_iterations = 5
mock_subprocess_run.return_value = subprocess.CompletedProcess(
args=["poetry", "run", "train", str(n_iterations)],
@@ -15,12 +14,10 @@ def test_train_crew_positive_iterations(mock_subprocess_run):
stderr="",
)
# Act
train_crew(n_iterations)
train_crew(n_iterations, "trained_agents_data.pkl")
# Assert
mock_subprocess_run.assert_called_once_with(
["poetry", "run", "train", str(n_iterations)],
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
capture_output=False,
text=True,
check=True,
@@ -29,7 +26,7 @@ def test_train_crew_positive_iterations(mock_subprocess_run):
@mock.patch("crewai.cli.train_crew.click")
def test_train_crew_zero_iterations(click):
train_crew(0)
train_crew(0, "trained_agents_data.pkl")
click.echo.assert_called_once_with(
"An unexpected error occurred: The number of iterations must be a positive integer.",
err=True,
@@ -38,7 +35,7 @@ def test_train_crew_zero_iterations(click):
@mock.patch("crewai.cli.train_crew.click")
def test_train_crew_negative_iterations(click):
train_crew(-2)
train_crew(-2, "trained_agents_data.pkl")
click.echo.assert_called_once_with(
"An unexpected error occurred: The number of iterations must be a positive integer.",
err=True,
@@ -55,10 +52,13 @@ def test_train_crew_called_process_error(mock_subprocess_run, click):
output="Error",
stderr="Some error occurred",
)
train_crew(n_iterations)
train_crew(n_iterations, "trained_agents_data.pkl")
mock_subprocess_run.assert_called_once_with(
["poetry", "run", "train", "5"], capture_output=False, text=True, check=True
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
capture_output=False,
text=True,
check=True,
)
click.echo.assert_has_calls(
[
@@ -74,13 +74,15 @@ def test_train_crew_called_process_error(mock_subprocess_run, click):
@mock.patch("crewai.cli.train_crew.click")
@mock.patch("crewai.cli.train_crew.subprocess.run")
def test_train_crew_unexpected_exception(mock_subprocess_run, click):
# Arrange
n_iterations = 5
mock_subprocess_run.side_effect = Exception("Unexpected error")
train_crew(n_iterations)
train_crew(n_iterations, "trained_agents_data.pkl")
mock_subprocess_run.assert_called_once_with(
["poetry", "run", "train", "5"], capture_output=False, text=True, check=True
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
capture_output=False,
text=True,
check=True,
)
click.echo.assert_called_once_with(
"An unexpected error occurred: Unexpected error", err=True

View File

@@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
import pydantic_core
import pytest
from crewai.agent import Agent
from crewai.agents.cache import CacheHandler
from crewai.crew import Crew
@@ -1806,7 +1807,9 @@ def test_crew_train_success(task_evaluator, crew_training_handler, kickoff):
agents=[researcher, writer],
tasks=[task],
)
crew.train(n_iterations=2, inputs={"topic": "AI"})
crew.train(
n_iterations=2, inputs={"topic": "AI"}, filename="trained_agents_data.pkl"
)
task_evaluator.assert_has_calls(
[
mock.call(researcher),
@@ -1890,7 +1893,7 @@ def test__setup_for_training():
for agent in agents:
assert agent.allow_delegation is True
crew._setup_for_training()
crew._setup_for_training("trained_agents_data.pkl")
assert crew._train is True
assert task.human_input is True