mirror of
https://github.com/crewAIInc/crewAI.git
synced 2025-12-25 16:58:29 +00:00
* fix: fix crewai-tools cli command * feat: add crewai train CLI command * feat: add the tests * fix: fix type hinting issue in code * fix: test.yml * fix: fix test * fix: removed fix since it didn't change the test
60 lines
1.6 KiB
Python
60 lines
1.6 KiB
Python
from unittest import mock
|
|
|
|
import pytest
|
|
from click.testing import CliRunner
|
|
|
|
from crewai.cli.cli import train, version
|
|
|
|
|
|
@pytest.fixture
def runner():
    """Build the Click ``CliRunner`` used to invoke the CLI commands under test."""
    cli_runner = CliRunner()
    return cli_runner
|
|
|
|
|
|
@mock.patch("crewai.cli.cli.train_crew")
def test_train_default_iterations(train_crew, runner):
    """Check that ``train`` without arguments trains the crew for 5 iterations."""
    # No CLI arguments are passed, so the command's default must be used.
    outcome = runner.invoke(train)
    expected_banner = "Training the crew for 5 iterations"

    # The patched ``train_crew`` must receive the default iteration count.
    train_crew.assert_called_once_with(5)
    assert outcome.exit_code == 0
    assert expected_banner in outcome.output
|
|
|
|
|
|
@mock.patch("crewai.cli.cli.train_crew")
def test_train_custom_iterations(train_crew, runner):
    """Check that ``--n_iterations 10`` is forwarded to ``train_crew`` as ``10``."""
    cli_args = ["--n_iterations", "10"]
    outcome = runner.invoke(train, cli_args)
    expected_banner = "Training the crew for 10 iterations"

    # The patched ``train_crew`` must receive the requested iteration count.
    train_crew.assert_called_once_with(10)
    assert outcome.exit_code == 0
    assert expected_banner in outcome.output
|
|
|
|
|
|
@mock.patch("crewai.cli.cli.train_crew")
def test_train_invalid_string_iterations(train_crew, runner):
    """Check that a non-integer ``--n_iterations`` value is rejected.

    The command is expected to exit with code 2, print the usage error text,
    and never call ``train_crew``.
    """
    bad_args = ["--n_iterations", "invalid"]
    # Exact usage/error text expected in the command output.
    expected_error = "Usage: train [OPTIONS]\nTry 'train --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n"

    outcome = runner.invoke(train, bad_args)

    train_crew.assert_not_called()
    assert outcome.exit_code == 2
    assert expected_error in outcome.output
|
|
|
|
|
|
def test_version_command(runner):
    """Check that ``version`` exits cleanly and prints the crewai version label."""
    outcome = runner.invoke(version)
    version_label = "crewai version:"

    assert outcome.exit_code == 0
    assert version_label in outcome.output
|
|
|
|
|
|
def test_version_command_with_tools(runner):
    """Check that ``version --tools`` reports the tools version or its absence."""
    outcome = runner.invoke(version, ["--tools"])
    printed = outcome.output

    assert outcome.exit_code == 0
    assert "crewai version:" in printed

    # Either message is acceptable: the test does not depend on whether the
    # tools package is installed.
    tools_reported = "crewai tools version:" in printed
    tools_missing = "crewai tools not installed" in printed
    assert tools_reported or tools_missing
|