mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
feat: add ability to train on custom file (#1161)
* feat: add ability to train on custom file * feat: add pkl file validation * feat: fix tests * feat: fix tests * feat: fix tests
This commit is contained in:
committed by
GitHub
parent
62f5b2fb2e
commit
51ee483e9d
@@ -15,18 +15,18 @@ def runner():
|
||||
def test_train_default_iterations(train_crew, runner):
|
||||
result = runner.invoke(train)
|
||||
|
||||
train_crew.assert_called_once_with(5)
|
||||
train_crew.assert_called_once_with(5, "trained_agents_data.pkl")
|
||||
assert result.exit_code == 0
|
||||
assert "Training the crew for 5 iterations" in result.output
|
||||
assert "Training the Crew for 5 iterations" in result.output
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.cli.train_crew")
|
||||
def test_train_custom_iterations(train_crew, runner):
|
||||
result = runner.invoke(train, ["--n_iterations", "10"])
|
||||
|
||||
train_crew.assert_called_once_with(10)
|
||||
train_crew.assert_called_once_with(10, "trained_agents_data.pkl")
|
||||
assert result.exit_code == 0
|
||||
assert "Training the crew for 10 iterations" in result.output
|
||||
assert "Training the Crew for 10 iterations" in result.output
|
||||
|
||||
|
||||
@mock.patch("crewai.cli.cli.train_crew")
|
||||
|
||||
@@ -6,7 +6,6 @@ from crewai.cli.train_crew import train_crew
|
||||
|
||||
@mock.patch("crewai.cli.train_crew.subprocess.run")
|
||||
def test_train_crew_positive_iterations(mock_subprocess_run):
|
||||
# Arrange
|
||||
n_iterations = 5
|
||||
mock_subprocess_run.return_value = subprocess.CompletedProcess(
|
||||
args=["poetry", "run", "train", str(n_iterations)],
|
||||
@@ -15,12 +14,10 @@ def test_train_crew_positive_iterations(mock_subprocess_run):
|
||||
stderr="",
|
||||
)
|
||||
|
||||
# Act
|
||||
train_crew(n_iterations)
|
||||
train_crew(n_iterations, "trained_agents_data.pkl")
|
||||
|
||||
# Assert
|
||||
mock_subprocess_run.assert_called_once_with(
|
||||
["poetry", "run", "train", str(n_iterations)],
|
||||
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
@@ -29,7 +26,7 @@ def test_train_crew_positive_iterations(mock_subprocess_run):
|
||||
|
||||
@mock.patch("crewai.cli.train_crew.click")
|
||||
def test_train_crew_zero_iterations(click):
|
||||
train_crew(0)
|
||||
train_crew(0, "trained_agents_data.pkl")
|
||||
click.echo.assert_called_once_with(
|
||||
"An unexpected error occurred: The number of iterations must be a positive integer.",
|
||||
err=True,
|
||||
@@ -38,7 +35,7 @@ def test_train_crew_zero_iterations(click):
|
||||
|
||||
@mock.patch("crewai.cli.train_crew.click")
|
||||
def test_train_crew_negative_iterations(click):
|
||||
train_crew(-2)
|
||||
train_crew(-2, "trained_agents_data.pkl")
|
||||
click.echo.assert_called_once_with(
|
||||
"An unexpected error occurred: The number of iterations must be a positive integer.",
|
||||
err=True,
|
||||
@@ -55,10 +52,13 @@ def test_train_crew_called_process_error(mock_subprocess_run, click):
|
||||
output="Error",
|
||||
stderr="Some error occurred",
|
||||
)
|
||||
train_crew(n_iterations)
|
||||
train_crew(n_iterations, "trained_agents_data.pkl")
|
||||
|
||||
mock_subprocess_run.assert_called_once_with(
|
||||
["poetry", "run", "train", "5"], capture_output=False, text=True, check=True
|
||||
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
click.echo.assert_has_calls(
|
||||
[
|
||||
@@ -74,13 +74,15 @@ def test_train_crew_called_process_error(mock_subprocess_run, click):
|
||||
@mock.patch("crewai.cli.train_crew.click")
|
||||
@mock.patch("crewai.cli.train_crew.subprocess.run")
|
||||
def test_train_crew_unexpected_exception(mock_subprocess_run, click):
|
||||
# Arrange
|
||||
n_iterations = 5
|
||||
mock_subprocess_run.side_effect = Exception("Unexpected error")
|
||||
train_crew(n_iterations)
|
||||
train_crew(n_iterations, "trained_agents_data.pkl")
|
||||
|
||||
mock_subprocess_run.assert_called_once_with(
|
||||
["poetry", "run", "train", "5"], capture_output=False, text=True, check=True
|
||||
["poetry", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
click.echo.assert_called_once_with(
|
||||
"An unexpected error occurred: Unexpected error", err=True
|
||||
|
||||
@@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pydantic_core
|
||||
import pytest
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.cache import CacheHandler
|
||||
from crewai.crew import Crew
|
||||
@@ -1806,7 +1807,9 @@ def test_crew_train_success(task_evaluator, crew_training_handler, kickoff):
|
||||
agents=[researcher, writer],
|
||||
tasks=[task],
|
||||
)
|
||||
crew.train(n_iterations=2, inputs={"topic": "AI"})
|
||||
crew.train(
|
||||
n_iterations=2, inputs={"topic": "AI"}, filename="trained_agents_data.pkl"
|
||||
)
|
||||
task_evaluator.assert_has_calls(
|
||||
[
|
||||
mock.call(researcher),
|
||||
@@ -1890,7 +1893,7 @@ def test__setup_for_training():
|
||||
for agent in agents:
|
||||
assert agent.allow_delegation is True
|
||||
|
||||
crew._setup_for_training()
|
||||
crew._setup_for_training("trained_agents_data.pkl")
|
||||
|
||||
assert crew._train is True
|
||||
assert task.human_input is True
|
||||
|
||||
Reference in New Issue
Block a user