Add --force option to crewai tool publish (#1383)

This commit adds an option to bypass Git remote validations when
publishing tools.
This commit is contained in:
Vini Brasil
2024-10-04 11:02:50 -03:00
committed by GitHub
parent d063ed3014
commit 5dee13e078
3 changed files with 316 additions and 271 deletions

View File

@@ -276,12 +276,13 @@ def tool_install(handle: str):
@tool.command(name="publish") @tool.command(name="publish")
@click.option("--force", is_flag=True, show_default=True, default=False, help="Bypasses Git remote validations")
@click.option("--public", "is_public", flag_value=True, default=False) @click.option("--public", "is_public", flag_value=True, default=False)
@click.option("--private", "is_public", flag_value=False) @click.option("--private", "is_public", flag_value=False)
def tool_publish(is_public: bool): def tool_publish(is_public: bool, force: bool):
tool_cmd = ToolCommand() tool_cmd = ToolCommand()
tool_cmd.login() tool_cmd.login()
tool_cmd.publish(is_public) tool_cmd.publish(is_public, force)
@crewai.group() @crewai.group()

View File

@@ -59,8 +59,8 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
finally: finally:
os.chdir(old_directory) os.chdir(old_directory)
def publish(self, is_public: bool): def publish(self, is_public: bool, force: bool = False):
if not git.Repository().is_synced(): if not git.Repository().is_synced() and not force:
console.print( console.print(
"[bold red]Failed to publish tool.[/bold red]\n" "[bold red]Failed to publish tool.[/bold red]\n"
"Local changes need to be resolved before publishing. Please do the following:\n" "Local changes need to be resolved before publishing. Please do the following:\n"

View File

@@ -1,16 +1,16 @@
from contextlib import contextmanager
import tempfile import tempfile
import unittest import unittest
import unittest.mock import unittest.mock
import os import os
from contextlib import contextmanager
from pytest import raises
from crewai.cli.tools.main import ToolCommand from crewai.cli.tools.main import ToolCommand
from io import StringIO from io import StringIO
from unittest.mock import patch, MagicMock from unittest.mock import patch, MagicMock
@contextmanager
class TestToolCommand(unittest.TestCase): def in_temp_dir():
@contextmanager
def in_temp_dir(self):
original_dir = os.getcwd() original_dir = os.getcwd()
with tempfile.TemporaryDirectory() as temp_dir: with tempfile.TemporaryDirectory() as temp_dir:
os.chdir(temp_dir) os.chdir(temp_dir)
@@ -19,9 +19,9 @@ class TestToolCommand(unittest.TestCase):
finally: finally:
os.chdir(original_dir) os.chdir(original_dir)
@patch("crewai.cli.tools.main.subprocess.run") @patch("crewai.cli.tools.main.subprocess.run")
def test_create_success(self, mock_subprocess): def test_create_success(mock_subprocess):
with self.in_temp_dir(): with in_temp_dir():
tool_command = ToolCommand() tool_command = ToolCommand()
with patch.object(tool_command, "login") as mock_login, patch( with patch.object(tool_command, "login") as mock_login, patch(
@@ -30,33 +30,28 @@ class TestToolCommand(unittest.TestCase):
tool_command.create("test-tool") tool_command.create("test-tool")
output = fake_out.getvalue() output = fake_out.getvalue()
self.assertTrue(os.path.isdir("test_tool")) assert os.path.isdir("test_tool")
assert os.path.isfile(os.path.join("test_tool", "README.md"))
self.assertTrue(os.path.isfile(os.path.join("test_tool", "README.md"))) assert os.path.isfile(os.path.join("test_tool", "pyproject.toml"))
self.assertTrue(os.path.isfile(os.path.join("test_tool", "pyproject.toml"))) assert os.path.isfile(
self.assertTrue(
os.path.isfile(
os.path.join("test_tool", "src", "test_tool", "__init__.py") os.path.join("test_tool", "src", "test_tool", "__init__.py")
) )
) assert os.path.isfile(os.path.join("test_tool", "src", "test_tool", "tool.py"))
self.assertTrue(
os.path.isfile(os.path.join("test_tool", "src", "test_tool", "tool.py"))
)
with open( with open(
os.path.join("test_tool", "src", "test_tool", "tool.py"), "r" os.path.join("test_tool", "src", "test_tool", "tool.py"), "r"
) as f: ) as f:
content = f.read() content = f.read()
self.assertIn("class TestTool", content) assert "class TestTool" in content
mock_login.assert_called_once() mock_login.assert_called_once()
mock_subprocess.assert_called_once_with(["git", "init"], check=True) mock_subprocess.assert_called_once_with(["git", "init"], check=True)
self.assertIn("Creating custom tool test_tool...", output) assert "Creating custom tool test_tool..." in output
@patch("crewai.cli.tools.main.subprocess.run") @patch("crewai.cli.tools.main.subprocess.run")
@patch("crewai.cli.plus_api.PlusAPI.get_tool") @patch("crewai.cli.plus_api.PlusAPI.get_tool")
def test_install_success(self, mock_get, mock_subprocess_run): def test_install_success(mock_get, mock_subprocess_run):
mock_get_response = MagicMock() mock_get_response = MagicMock()
mock_get_response.status_code = 200 mock_get_response.status_code = 200
mock_get_response.json.return_value = { mock_get_response.json.return_value = {
@@ -80,10 +75,10 @@ class TestToolCommand(unittest.TestCase):
check=True, check=True,
) )
self.assertIn("Succesfully installed sample-tool", output) assert "Succesfully installed sample-tool" in output
@patch("crewai.cli.plus_api.PlusAPI.get_tool") @patch("crewai.cli.plus_api.PlusAPI.get_tool")
def test_install_tool_not_found(self, mock_get): def test_install_tool_not_found(mock_get):
mock_get_response = MagicMock() mock_get_response = MagicMock()
mock_get_response.status_code = 404 mock_get_response.status_code = 404
mock_get.return_value = mock_get_response mock_get.return_value = mock_get_response
@@ -91,15 +86,17 @@ class TestToolCommand(unittest.TestCase):
tool_command = ToolCommand() tool_command = ToolCommand()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch("sys.stdout", new=StringIO()) as fake_out:
with self.assertRaises(SystemExit): try:
tool_command.install("non-existent-tool") tool_command.install("non-existent-tool")
except SystemExit:
pass
output = fake_out.getvalue() output = fake_out.getvalue()
mock_get.assert_called_once_with("non-existent-tool") mock_get.assert_called_once_with("non-existent-tool")
self.assertIn("No tool found with this name", output) assert "No tool found with this name" in output
@patch("crewai.cli.plus_api.PlusAPI.get_tool") @patch("crewai.cli.plus_api.PlusAPI.get_tool")
def test_install_api_error(self, mock_get): def test_install_api_error(mock_get):
mock_get_response = MagicMock() mock_get_response = MagicMock()
mock_get_response.status_code = 500 mock_get_response.status_code = 500
mock_get.return_value = mock_get_response mock_get.return_value = mock_get_response
@@ -107,31 +104,37 @@ class TestToolCommand(unittest.TestCase):
tool_command = ToolCommand() tool_command = ToolCommand()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch("sys.stdout", new=StringIO()) as fake_out:
with self.assertRaises(SystemExit): try:
tool_command.install("error-tool") tool_command.install("error-tool")
except SystemExit:
pass
output = fake_out.getvalue() output = fake_out.getvalue()
mock_get.assert_called_once_with("error-tool") mock_get.assert_called_once_with("error-tool")
self.assertIn("Failed to get tool details", output) assert "Failed to get tool details" in output
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool") @patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=False)
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0") def test_publish_when_not_in_sync(mock_is_synced):
@patch( with patch("sys.stdout", new=StringIO()) as fake_out, \
"crewai.cli.tools.main.get_project_description", return_value="A sample tool" raises(SystemExit):
) tool_command = ToolCommand()
@patch("crewai.cli.tools.main.subprocess.run") tool_command.publish(is_public=True)
@patch(
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"] assert "Local changes need to be resolved before publishing" in fake_out.getvalue()
)
@patch( @patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
@patch("crewai.cli.tools.main.get_project_description", return_value="A sample tool")
@patch("crewai.cli.tools.main.subprocess.run")
@patch("crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"])
@patch(
"crewai.cli.tools.main.open", "crewai.cli.tools.main.open",
new_callable=unittest.mock.mock_open, new_callable=unittest.mock.mock_open,
read_data=b"sample tarball content", read_data=b"sample tarball content",
) )
@patch("crewai.cli.plus_api.PlusAPI.publish_tool") @patch("crewai.cli.plus_api.PlusAPI.publish_tool")
@patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=True) @patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=False)
def test_publish_success( def test_publish_when_not_in_sync_and_force(
self,
mock_is_synced, mock_is_synced,
mock_publish, mock_publish,
mock_open, mock_open,
@@ -140,7 +143,54 @@ class TestToolCommand(unittest.TestCase):
mock_get_project_description, mock_get_project_description,
mock_get_project_version, mock_get_project_version,
mock_get_project_name, mock_get_project_name,
): ):
mock_publish_response = MagicMock()
mock_publish_response.status_code = 200
mock_publish_response.json.return_value = {"handle": "sample-tool"}
mock_publish.return_value = mock_publish_response
tool_command = ToolCommand()
tool_command.publish(is_public=True, force=True)
mock_get_project_name.assert_called_with(require=True)
mock_get_project_version.assert_called_with(require=True)
mock_get_project_description.assert_called_with(require=False)
mock_subprocess_run.assert_called_with(
["poetry", "build", "-f", "sdist", "--output", unittest.mock.ANY],
check=True,
capture_output=False,
)
mock_open.assert_called_with(unittest.mock.ANY, "rb")
mock_publish.assert_called_with(
handle="sample-tool",
is_public=True,
version="1.0.0",
description="A sample tool",
encoded_file=unittest.mock.ANY,
)
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
@patch("crewai.cli.tools.main.get_project_description", return_value="A sample tool")
@patch("crewai.cli.tools.main.subprocess.run")
@patch("crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"])
@patch(
"crewai.cli.tools.main.open",
new_callable=unittest.mock.mock_open,
read_data=b"sample tarball content",
)
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
@patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=True)
def test_publish_success(
mock_is_synced,
mock_publish,
mock_open,
mock_listdir,
mock_subprocess_run,
mock_get_project_description,
mock_get_project_version,
mock_get_project_name,
):
mock_publish_response = MagicMock() mock_publish_response = MagicMock()
mock_publish_response.status_code = 200 mock_publish_response.status_code = 200
mock_publish_response.json.return_value = {"handle": "sample-tool"} mock_publish_response.json.return_value = {"handle": "sample-tool"}
@@ -166,23 +216,18 @@ class TestToolCommand(unittest.TestCase):
encoded_file=unittest.mock.ANY, encoded_file=unittest.mock.ANY,
) )
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool") @patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0") @patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
@patch( @patch("crewai.cli.tools.main.get_project_description", return_value="A sample tool")
"crewai.cli.tools.main.get_project_description", return_value="A sample tool" @patch("crewai.cli.tools.main.subprocess.run")
) @patch("crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"])
@patch("crewai.cli.tools.main.subprocess.run") @patch(
@patch(
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
)
@patch(
"crewai.cli.tools.main.open", "crewai.cli.tools.main.open",
new_callable=unittest.mock.mock_open, new_callable=unittest.mock.mock_open,
read_data=b"sample tarball content", read_data=b"sample tarball content",
) )
@patch("crewai.cli.plus_api.PlusAPI.publish_tool") @patch("crewai.cli.plus_api.PlusAPI.publish_tool")
def test_publish_failure( def test_publish_failure(
self,
mock_publish, mock_publish,
mock_open, mock_open,
mock_listdir, mock_listdir,
@@ -190,7 +235,7 @@ class TestToolCommand(unittest.TestCase):
mock_get_project_description, mock_get_project_description,
mock_get_project_version, mock_get_project_version,
mock_get_project_name, mock_get_project_name,
): ):
mock_publish_response = MagicMock() mock_publish_response = MagicMock()
mock_publish_response.status_code = 422 mock_publish_response.status_code = 422
mock_publish_response.json.return_value = {"name": ["is already taken"]} mock_publish_response.json.return_value = {"name": ["is already taken"]}
@@ -199,31 +244,28 @@ class TestToolCommand(unittest.TestCase):
tool_command = ToolCommand() tool_command = ToolCommand()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch("sys.stdout", new=StringIO()) as fake_out:
with self.assertRaises(SystemExit): try:
tool_command.publish(is_public=True) tool_command.publish(is_public=True)
except SystemExit:
pass
output = fake_out.getvalue() output = fake_out.getvalue()
mock_publish.assert_called_once() mock_publish.assert_called_once()
self.assertIn("Failed to complete operation", output) assert "Failed to complete operation" in output
self.assertIn("Name is already taken", output) assert "Name is already taken" in output
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool") @patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0") @patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
@patch( @patch("crewai.cli.tools.main.get_project_description", return_value="A sample tool")
"crewai.cli.tools.main.get_project_description", return_value="A sample tool" @patch("crewai.cli.tools.main.subprocess.run")
) @patch("crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"])
@patch("crewai.cli.tools.main.subprocess.run") @patch(
@patch(
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
)
@patch(
"crewai.cli.tools.main.open", "crewai.cli.tools.main.open",
new_callable=unittest.mock.mock_open, new_callable=unittest.mock.mock_open,
read_data=b"sample tarball content", read_data=b"sample tarball content",
) )
@patch("crewai.cli.plus_api.PlusAPI.publish_tool") @patch("crewai.cli.plus_api.PlusAPI.publish_tool")
def test_publish_api_error( def test_publish_api_error(
self,
mock_publish, mock_publish,
mock_open, mock_open,
mock_listdir, mock_listdir,
@@ -231,7 +273,7 @@ class TestToolCommand(unittest.TestCase):
mock_get_project_description, mock_get_project_description,
mock_get_project_version, mock_get_project_version,
mock_get_project_name, mock_get_project_name,
): ):
mock_response = MagicMock() mock_response = MagicMock()
mock_response.status_code = 500 mock_response.status_code = 500
mock_response.json.return_value = {"error": "Internal Server Error"} mock_response.json.return_value = {"error": "Internal Server Error"}
@@ -241,16 +283,18 @@ class TestToolCommand(unittest.TestCase):
tool_command = ToolCommand() tool_command = ToolCommand()
with patch("sys.stdout", new=StringIO()) as fake_out: with patch("sys.stdout", new=StringIO()) as fake_out:
with self.assertRaises(SystemExit): try:
tool_command.publish(is_public=True) tool_command.publish(is_public=True)
except SystemExit:
pass
output = fake_out.getvalue() output = fake_out.getvalue()
mock_publish.assert_called_once() mock_publish.assert_called_once()
self.assertIn("Request to Enterprise API failed", output) assert "Request to Enterprise API failed" in output
@patch("crewai.cli.plus_api.PlusAPI.login_to_tool_repository") @patch("crewai.cli.plus_api.PlusAPI.login_to_tool_repository")
@patch("crewai.cli.tools.main.subprocess.run") @patch("crewai.cli.tools.main.subprocess.run")
def test_login_success(self, mock_subprocess_run, mock_login): def test_login_success(mock_subprocess_run, mock_login):
mock_login_response = MagicMock() mock_login_response = MagicMock()
mock_login_response.status_code = 200 mock_login_response.status_code = 200
mock_login_response.json.return_value = { mock_login_response.json.return_value = {
@@ -297,4 +341,4 @@ class TestToolCommand(unittest.TestCase):
text=True, text=True,
check=True, check=True,
) )
self.assertIn("Succesfully authenticated to the tool repository", output) assert "Succesfully authenticated to the tool repository" in output