From 35ef69d0e8afbeed01bca6e86bfd378af84fce6a Mon Sep 17 00:00:00 2001 From: Vini Brasil Date: Thu, 26 Sep 2024 17:23:31 -0300 Subject: [PATCH] CLI for Tool Repository (#1357) This commit adds two commands to the CLI: - `crewai tool publish` - Builds the project using Poetry - Uploads the tarball to CrewAI's tool repository - `crewai tool install my-tool` - Adds my-tool's index to Poetry and its credentials - Installs my-tool from the custom index --- src/crewai/cli/cli.py | 22 ++ src/crewai/cli/command.py | 40 +++ src/crewai/cli/deploy/api.py | 56 ----- src/crewai/cli/deploy/main.py | 70 ++---- src/crewai/cli/plus_api.py | 68 ++++++ src/crewai/cli/tools/api.py | 26 -- src/crewai/cli/tools/main.py | 168 +++++++++++++ src/crewai/cli/utils.py | 63 ++++- tests/cli/deploy/test_deploy_main.py | 24 +- .../{deploy/test_api.py => test_plus_api.py} | 107 ++++++-- tests/cli/tools/test_api.py | 69 ------ tests/cli/tools/test_main.py | 229 ++++++++++++++++++ 12 files changed, 696 insertions(+), 246 deletions(-) create mode 100644 src/crewai/cli/command.py delete mode 100644 src/crewai/cli/deploy/api.py delete mode 100644 src/crewai/cli/tools/api.py create mode 100644 src/crewai/cli/tools/main.py rename tests/cli/{deploy/test_api.py => test_plus_api.py} (52%) delete mode 100644 tests/cli/tools/test_api.py create mode 100644 tests/cli/tools/test_main.py diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index 0b1311946..b54b97c12 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -11,6 +11,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import ( from .authentication.main import AuthenticationCommand from .deploy.main import DeployCommand +from .tools.main import ToolCommand from .evaluate_crew import evaluate_crew from .install_crew import install_crew from .replay_from_task import replay_task_command @@ -202,6 +203,12 @@ def deploy(): pass +@crewai.group() +def tool(): + """Tool Repository related commands.""" + pass + + @deploy.command(name="create") @click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt") def deploy_create(yes: bool): @@ -249,5 +256,20 @@ def deploy_remove(uuid: Optional[str]): deploy_cmd.remove_crew(uuid=uuid) +@tool.command(name="install") +@click.argument("handle") +def tool_install(handle: str): + tool_cmd = ToolCommand() + tool_cmd.install(handle) + + +@tool.command(name="publish") +@click.option("--public", "is_public", flag_value=True, default=False) +@click.option("--private", "is_public", flag_value=False) +def tool_publish(is_public: bool): + tool_cmd = ToolCommand() + tool_cmd.publish(is_public) + + if __name__ == "__main__": crewai() diff --git a/src/crewai/cli/command.py b/src/crewai/cli/command.py new file mode 100644 index 000000000..0b12b9082 --- /dev/null +++ b/src/crewai/cli/command.py @@ -0,0 +1,40 @@ +from typing import Dict, Any +from rich.console import Console +from crewai.cli.plus_api import PlusAPI +from crewai.cli.utils import get_auth_token +from crewai.telemetry.telemetry import Telemetry + +console = Console() + + +class BaseCommand: + def __init__(self): + self._telemetry = Telemetry() + self._telemetry.set_tracer() + + +class PlusAPIMixin: + def __init__(self, telemetry): + try: + telemetry.set_tracer() + self.plus_api_client = PlusAPI(api_key=get_auth_token()) + except Exception: + self._deploy_signup_error_span = telemetry.deploy_signup_error_span() + console.print( + "Please sign up/login to CrewAI+ before using the CLI.", + style="bold red", + ) + console.print("Run 'crewai signup' to sign up/login.", style="bold green") + raise SystemExit + + def _handle_plus_api_error(self, json_response: Dict[str, Any]) -> None: + """ + Handle and display error messages from API responses. + + Args: + json_response (Dict[str, Any]): The JSON response containing error information. + """ + error = json_response.get("error", "Unknown error") + message = json_response.get("message", "No message provided") + console.print(f"Error: {error}", style="bold red") + console.print(f"Message: {message}", style="bold red") diff --git a/src/crewai/cli/deploy/api.py b/src/crewai/cli/deploy/api.py deleted file mode 100644 index 4a8954cdc..000000000 --- a/src/crewai/cli/deploy/api.py +++ /dev/null @@ -1,56 +0,0 @@ -import requests -from crewai.cli.plus_api import PlusAPI - - -class CrewAPI(PlusAPI): - """ - CrewAPI class to interact with the Crew resource in CrewAI+ API. - """ - - RESOURCE = "/crewai_plus/api/v1/crews" - - # Deploy - def deploy_by_name(self, project_name: str) -> requests.Response: - return self._make_request( - "POST", f"{self.RESOURCE}/by-name/{project_name}/deploy" - ) - - def deploy_by_uuid(self, uuid: str) -> requests.Response: - return self._make_request("POST", f"{self.RESOURCE}/{uuid}/deploy") - - # Status - def status_by_name(self, project_name: str) -> requests.Response: - return self._make_request( - "GET", f"{self.RESOURCE}/by-name/{project_name}/status" - ) - - def status_by_uuid(self, uuid: str) -> requests.Response: - return self._make_request("GET", f"{self.RESOURCE}/{uuid}/status") - - # Logs - def logs_by_name( - self, project_name: str, log_type: str = "deployment" - ) -> requests.Response: - return self._make_request( - "GET", f"{self.RESOURCE}/by-name/{project_name}/logs/{log_type}" - ) - - def logs_by_uuid( - self, uuid: str, log_type: str = "deployment" - ) -> requests.Response: - return self._make_request("GET", f"{self.RESOURCE}/{uuid}/logs/{log_type}") - - # Delete - def delete_by_name(self, project_name: str) -> requests.Response: - return self._make_request("DELETE", f"{self.RESOURCE}/by-name/{project_name}") - - def delete_by_uuid(self, uuid: str) -> requests.Response: - return self._make_request("DELETE", f"{self.RESOURCE}/{uuid}") - - # List - def list_crews(self) -> requests.Response: - return self._make_request("GET", self.RESOURCE) - - # Create - def create_crew(self, payload) -> requests.Response: - return self._make_request("POST", self.RESOURCE, json=payload) diff --git a/src/crewai/cli/deploy/main.py b/src/crewai/cli/deploy/main.py index 18ed26c04..d6c9d8fe6 100644 --- a/src/crewai/cli/deploy/main.py +++ b/src/crewai/cli/deploy/main.py @@ -2,19 +2,17 @@ from typing import Any, Dict, List, Optional from rich.console import Console -from crewai.telemetry import Telemetry +from crewai.cli.command import BaseCommand, PlusAPIMixin from crewai.cli.utils import ( fetch_and_json_env_file, - get_auth_token, get_git_remote_url, get_project_name, ) -from .api import CrewAPI console = Console() -class DeployCommand: +class DeployCommand(BaseCommand, PlusAPIMixin): """ A class to handle deployment-related operations for CrewAI projects. """ @@ -23,40 +21,10 @@ class DeployCommand: """ Initialize the DeployCommand with project name and API client. """ - try: - self._telemetry = Telemetry() - self._telemetry.set_tracer() - access_token = get_auth_token() - except Exception: - self._deploy_signup_error_span = self._telemetry.deploy_signup_error_span() - console.print( - "Please sign up/login to CrewAI+ before using the CLI.", - style="bold red", - ) - console.print("Run 'crewai signup' to sign up/login.", style="bold green") - raise SystemExit - self.project_name = get_project_name() - if self.project_name is None: - console.print( - "No project name found. Please ensure your project has a valid pyproject.toml file.", - style="bold red", - ) - raise SystemExit - - self.client = CrewAPI(api_key=access_token) - - def _handle_error(self, json_response: Dict[str, Any]) -> None: - """ - Handle and display error messages from API responses. - - Args: - json_response (Dict[str, Any]): The JSON response containing error information. - """ - error = json_response.get("error", "Unknown error") - message = json_response.get("message", "No message provided") - console.print(f"Error: {error}", style="bold red") - console.print(f"Message: {message}", style="bold red") + BaseCommand.__init__(self) + PlusAPIMixin.__init__(self, telemetry=self._telemetry) + self.project_name = get_project_name(require=True) def _standard_no_param_error_message(self) -> None: """ @@ -104,9 +72,9 @@ class DeployCommand: self._start_deployment_span = self._telemetry.start_deployment_span(uuid) console.print("Starting deployment...", style="bold blue") if uuid: - response = self.client.deploy_by_uuid(uuid) + response = self.plus_api_client.deploy_by_uuid(uuid) elif self.project_name: - response = self.client.deploy_by_name(self.project_name) + response = self.plus_api_client.deploy_by_name(self.project_name) else: self._standard_no_param_error_message() return @@ -115,7 +83,7 @@ class DeployCommand: if response.status_code == 200: self._display_deployment_info(json_response) else: - self._handle_error(json_response) + self._handle_plus_api_error(json_response) def create_crew(self, confirm: bool = False) -> None: """ @@ -139,11 +107,11 @@ class DeployCommand: self._confirm_input(env_vars, remote_repo_url, confirm) payload = self._create_payload(env_vars, remote_repo_url) - response = self.client.create_crew(payload) + response = self.plus_api_client.create_crew(payload) if response.status_code == 201: self._display_creation_success(response.json()) else: - self._handle_error(response.json()) + self._handle_plus_api_error(response.json()) def _confirm_input( self, env_vars: Dict[str, str], remote_repo_url: str, confirm: bool @@ -208,7 +176,7 @@ class DeployCommand: """ console.print("Listing all Crews\n", style="bold blue") - response = self.client.list_crews() + response = self.plus_api_client.list_crews() json_response = response.json() if response.status_code == 200: self._display_crews(json_response) @@ -243,9 +211,9 @@ class DeployCommand: """ console.print("Fetching deployment status...", style="bold blue") if uuid: - response = self.client.status_by_uuid(uuid) + response = self.plus_api_client.crew_status_by_uuid(uuid) elif self.project_name: - response = self.client.status_by_name(self.project_name) + response = self.plus_api_client.crew_status_by_name(self.project_name) else: self._standard_no_param_error_message() return @@ -254,7 +222,7 @@ class DeployCommand: if response.status_code == 200: self._display_crew_status(json_response) else: - self._handle_error(json_response) + self._handle_plus_api_error(json_response) def _display_crew_status(self, status_data: Dict[str, str]) -> None: """ @@ -278,9 +246,9 @@ class DeployCommand: console.print(f"Fetching {log_type} logs...", style="bold blue") if uuid: - response = self.client.logs_by_uuid(uuid, log_type) + response = self.plus_api_client.crew_by_uuid(uuid, log_type) elif self.project_name: - response = self.client.logs_by_name(self.project_name, log_type) + response = self.plus_api_client.crew_by_name(self.project_name, log_type) else: self._standard_no_param_error_message() return @@ -288,7 +256,7 @@ class DeployCommand: if response.status_code == 200: self._display_logs(response.json()) else: - self._handle_error(response.json()) + self._handle_plus_api_error(response.json()) def remove_crew(self, uuid: Optional[str]) -> None: """ @@ -301,9 +269,9 @@ class DeployCommand: console.print("Removing deployment...", style="bold blue") if uuid: - response = self.client.delete_by_uuid(uuid) + response = self.plus_api_client.delete_crew_by_uuid(uuid) elif self.project_name: - response = self.client.delete_by_name(self.project_name) + response = self.plus_api_client.delete_crew_by_name(self.project_name) else: self._standard_no_param_error_message() return diff --git a/src/crewai/cli/plus_api.py b/src/crewai/cli/plus_api.py index e829f1199..e72d27bfe 100644 --- a/src/crewai/cli/plus_api.py +++ b/src/crewai/cli/plus_api.py @@ -1,3 +1,4 @@ +from typing import Optional import requests from os import getenv from crewai.cli.utils import get_crewai_version @@ -9,6 +10,9 @@ class PlusAPI: This class exposes methods for working with the CrewAI+ API. """ + TOOLS_RESOURCE = "/crewai_plus/api/v1/tools" + CREWS_RESOURCE = "/crewai_plus/api/v1/crews" + def __init__(self, api_key: str) -> None: self.api_key = api_key self.headers = { @@ -22,3 +26,67 @@ class PlusAPI: def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response: url = urljoin(self.base_url, endpoint) return requests.request(method, url, headers=self.headers, **kwargs) + + def get_tool(self, handle: str): + return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}") + + def publish_tool( + self, + handle: str, + is_public: bool, + version: str, + description: Optional[str], + encoded_file: str, + ): + params = { + "handle": handle, + "public": is_public, + "version": version, + "file": encoded_file, + "description": description, + } + return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params) + + def deploy_by_name(self, project_name: str) -> requests.Response: + return self._make_request( + "POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy" + ) + + def deploy_by_uuid(self, uuid: str) -> requests.Response: + return self._make_request("POST", f"{self.CREWS_RESOURCE}/{uuid}/deploy") + + def crew_status_by_name(self, project_name: str) -> requests.Response: + return self._make_request( + "GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status" + ) + + def crew_status_by_uuid(self, uuid: str) -> requests.Response: + return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status") + + def crew_by_name( + self, project_name: str, log_type: str = "deployment" + ) -> requests.Response: + return self._make_request( + "GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}" + ) + + def crew_by_uuid( + self, uuid: str, log_type: str = "deployment" + ) -> requests.Response: + return self._make_request( + "GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}" + ) + + def delete_crew_by_name(self, project_name: str) -> requests.Response: + return self._make_request( + "DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}" + ) + + def delete_crew_by_uuid(self, uuid: str) -> requests.Response: + return self._make_request("DELETE", f"{self.CREWS_RESOURCE}/{uuid}") + + def list_crews(self) -> requests.Response: + return self._make_request("GET", self.CREWS_RESOURCE) + + def create_crew(self, payload) -> requests.Response: + return self._make_request("POST", self.CREWS_RESOURCE, json=payload) diff --git a/src/crewai/cli/tools/api.py b/src/crewai/cli/tools/api.py deleted file mode 100644 index c93fdb9c0..000000000 --- a/src/crewai/cli/tools/api.py +++ /dev/null @@ -1,26 +0,0 @@ -from typing import Optional -from crewai.cli.plus_api import PlusAPI - - -class ToolsAPI(PlusAPI): - RESOURCE = "/crewai_plus/api/v1/tools" - - def get(self, handle: str): - return self._make_request("GET", f"{self.RESOURCE}/{handle}") - - def publish( - self, - handle: str, - public: bool, - version: str, - description: Optional[str], - encoded_file: str, - ): - params = { - "handle": handle, - "public": public, - "version": version, - "file": encoded_file, - "description": description, - } - return self._make_request("POST", f"{self.RESOURCE}", json=params) diff --git a/src/crewai/cli/tools/main.py b/src/crewai/cli/tools/main.py new file mode 100644 index 000000000..8acbcedd5 --- /dev/null +++ b/src/crewai/cli/tools/main.py @@ -0,0 +1,168 @@ +import base64 +import click +import os +import subprocess +import tempfile + +from crewai.cli.command import BaseCommand, PlusAPIMixin +from crewai.cli.utils import ( + get_project_name, + get_project_description, + get_project_version, +) +from rich.console import Console + +console = Console() + + +class ToolCommand(BaseCommand, PlusAPIMixin): + """ + A class to handle tool repository related operations for CrewAI projects. + """ + + def __init__(self): + BaseCommand.__init__(self) + PlusAPIMixin.__init__(self, telemetry=self._telemetry) + + def publish(self, is_public: bool): + project_name = get_project_name(require=True) + assert isinstance(project_name, str) + + project_version = get_project_version(require=True) + assert isinstance(project_version, str) + + project_description = get_project_description(require=False) + encoded_tarball = None + + with tempfile.TemporaryDirectory() as temp_build_dir: + subprocess.run( + ["poetry", "build", "-f", "sdist", "--output", temp_build_dir], + check=True, + capture_output=False, + ) + + tarball_filename = next( + (f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None + ) + if not tarball_filename: + console.print( + "Project build failed. Please ensure that the command `poetry build -f sdist` completes successfully.", + style="bold red", + ) + raise SystemExit + + tarball_path = os.path.join(temp_build_dir, tarball_filename) + with open(tarball_path, "rb") as file: + tarball_contents = file.read() + + encoded_tarball = base64.b64encode(tarball_contents).decode("utf-8") + + publish_response = self.plus_api_client.publish_tool( + handle=project_name, + is_public=is_public, + version=project_version, + description=project_description, + encoded_file=f"data:application/x-gzip;base64,{encoded_tarball}", + ) + if publish_response.status_code == 422: + console.print( + "[bold red]Failed to publish tool. Please fix the following errors:[/bold red]" + ) + for field, messages in publish_response.json().items(): + for message in messages: + console.print( + f"* [bold red]{field.capitalize()}[/bold red] {message}" + ) + + raise SystemExit + elif publish_response.status_code != 200: + self._handle_plus_api_error(publish_response.json()) + console.print( + "Failed to publish tool. Please try again later.", style="bold red" + ) + raise SystemExit + + published_handle = publish_response.json()["handle"] + console.print( + f"Succesfully published {published_handle} ({project_version}).\nInstall it in other projects with crewai tool install {published_handle}", + style="bold green", + ) + + def install(self, handle: str): + get_response = self.plus_api_client.get_tool(handle) + + if get_response.status_code == 404: + console.print( + "No tool found with this name. Please ensure the tool was published and you have access to it.", + style="bold red", + ) + raise SystemExit + elif get_response.status_code != 200: + console.print( + "Failed to get tool details. Please try again later.", style="bold red" + ) + raise SystemExit + + self._add_repository_to_poetry(get_response.json()) + self._add_package(get_response.json()) + + console.print(f"Succesfully installed {handle}", style="bold green") + + def _add_repository_to_poetry(self, tool_details): + repository_handle = f"crewai-{tool_details['repository']['handle']}" + repository_url = tool_details["repository"]["url"] + repository_credentials = tool_details["repository"]["credentials"] + + add_repository_command = [ + "poetry", + "source", + "add", + "--priority=explicit", + repository_handle, + repository_url, + ] + add_repository_result = subprocess.run( + add_repository_command, text=True, check=True + ) + + if add_repository_result.stderr: + click.echo(add_repository_result.stderr, err=True) + raise SystemExit + + add_repository_credentials_command = [ + "poetry", + "config", + f"http-basic.{repository_handle}", + repository_credentials, + '""', + ] + add_repository_credentials_result = subprocess.run( + add_repository_credentials_command, + capture_output=False, + text=True, + check=True, + ) + + if add_repository_credentials_result.stderr: + click.echo(add_repository_credentials_result.stderr, err=True) + raise SystemExit + + def _add_package(self, tool_details): + tool_handle = tool_details["handle"] + repository_handle = tool_details["repository"]["handle"] + pypi_index_handle = f"crewai-{repository_handle}" + + add_package_command = [ + "poetry", + "add", + "--source", + pypi_index_handle, + tool_handle, + ] + add_package_result = subprocess.run( + add_package_command, capture_output=False, text=True, check=True + ) + + if add_package_result.stderr: + click.echo(add_package_result.stderr, err=True) + raise SystemExit diff --git a/src/crewai/cli/utils.py b/src/crewai/cli/utils.py index 363cdd0cf..58aa154dd 100644 --- a/src/crewai/cli/utils.py +++ b/src/crewai/cli/utils.py @@ -3,8 +3,10 @@ import re import subprocess import sys -from rich.console import Console from crewai.cli.authentication.utils import TokenManager +from functools import reduce +from rich.console import Console +from typing import Any, Dict, List if sys.version_info >= (3, 11): import tomllib @@ -88,21 +90,51 @@ def get_git_remote_url() -> str | None: return None -def get_project_name(pyproject_path: str = "pyproject.toml") -> str | None: +def get_project_name( + pyproject_path: str = "pyproject.toml", require: bool = False +) -> str | None: """Get the project name from the pyproject.toml file.""" + return _get_project_attribute( + pyproject_path, ["tool", "poetry", "name"], require=require + ) + + +def get_project_version( + pyproject_path: str = "pyproject.toml", require: bool = False +) -> str | None: + """Get the project version from the pyproject.toml file.""" + return _get_project_attribute( + pyproject_path, ["tool", "poetry", "version"], require=require + ) + + +def get_project_description( + pyproject_path: str = "pyproject.toml", require: bool = False +) -> str | None: + """Get the project description from the pyproject.toml file.""" + return _get_project_attribute( + pyproject_path, ["tool", "poetry", "description"], require=require + ) + + +def _get_project_attribute( + pyproject_path: str, keys: List[str], require: bool +) -> Any | None: + """Get an attribute from the pyproject.toml file.""" + attribute = None + try: - # Read the pyproject.toml file with open(pyproject_path, "r") as f: pyproject_content = parse_toml(f.read()) - # Extract the project name - project_name = pyproject_content["tool"]["poetry"]["name"] - - if "crewai" not in pyproject_content["tool"]["poetry"]["dependencies"]: + dependencies = ( + _get_nested_value(pyproject_content, ["tool", "poetry", "dependencies"]) + or {} + ) + if "crewai" not in dependencies: raise Exception("crewai is not in the dependencies.") - return project_name - + attribute = _get_nested_value(pyproject_content, keys) except FileNotFoundError: print(f"Error: {pyproject_path} not found.") except KeyError: @@ -116,7 +148,18 @@ def get_project_name(pyproject_path: str = "pyproject.toml") -> str | None: except Exception as e: print(f"Error reading the pyproject.toml file: {e}") - return None + if require and not attribute: + console.print( + f"Unable to read '{'.'.join(keys)}' in the pyproject.toml file. Please verify that the file exists and contains the specified attribute.", + style="bold red", + ) + raise SystemExit + + return attribute + + +def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any: + return reduce(dict.__getitem__, keys, data) def get_crewai_version(poetry_lock_path: str = "poetry.lock") -> str: diff --git a/tests/cli/deploy/test_deploy_main.py b/tests/cli/deploy/test_deploy_main.py index 2e8b62a18..ddb9709d3 100644 --- a/tests/cli/deploy/test_deploy_main.py +++ b/tests/cli/deploy/test_deploy_main.py @@ -8,34 +8,34 @@ from crewai.cli.utils import parse_toml class TestDeployCommand(unittest.TestCase): - @patch("crewai.cli.deploy.main.get_auth_token") + @patch("crewai.cli.command.get_auth_token") @patch("crewai.cli.deploy.main.get_project_name") - @patch("crewai.cli.deploy.main.CrewAPI") - def setUp(self, mock_crew_api, mock_get_project_name, mock_get_auth_token): + @patch("crewai.cli.command.PlusAPI") + def setUp(self, mock_plus_api, mock_get_project_name, mock_get_auth_token): self.mock_get_auth_token = mock_get_auth_token self.mock_get_project_name = mock_get_project_name - self.mock_crew_api = mock_crew_api + self.mock_plus_api = mock_plus_api self.mock_get_auth_token.return_value = "test_token" self.mock_get_project_name.return_value = "test_project" self.deploy_command = DeployCommand() - self.mock_client = self.deploy_command.client + self.mock_client = self.deploy_command.plus_api_client def test_init_success(self): self.assertEqual(self.deploy_command.project_name, "test_project") - self.mock_crew_api.assert_called_once_with(api_key="test_token") + self.mock_plus_api.assert_called_once_with(api_key="test_token") - @patch("crewai.cli.deploy.main.get_auth_token") + @patch("crewai.cli.command.get_auth_token") def test_init_failure(self, mock_get_auth_token): mock_get_auth_token.side_effect = Exception("Auth failed") with self.assertRaises(SystemExit): DeployCommand() - def test_handle_error(self): + def test_handle_plus_api_error(self): with patch("sys.stdout", new=StringIO()) as fake_out: - self.deploy_command._handle_error( + self.deploy_command._handle_plus_api_error( {"error": "Test error", "message": "Test message"} ) self.assertIn("Error: Test error", fake_out.getvalue()) @@ -122,7 +122,7 @@ class TestDeployCommand(unittest.TestCase): mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = {"name": "TestCrew", "status": "active"} - self.mock_client.status_by_name.return_value = mock_response + self.mock_client.crew_status_by_name.return_value = mock_response with patch("sys.stdout", new=StringIO()) as fake_out: self.deploy_command.get_crew_status() @@ -136,7 +136,7 @@ class TestDeployCommand(unittest.TestCase): {"timestamp": "2023-01-01", "level": "INFO", "message": "Log1"}, {"timestamp": "2023-01-02", "level": "ERROR", "message": "Log2"}, ] - self.mock_client.logs_by_name.return_value = mock_response + self.mock_client.crew_by_name.return_value = mock_response with patch("sys.stdout", new=StringIO()) as fake_out: self.deploy_command.get_crew_logs(None) @@ -146,7 +146,7 @@ class TestDeployCommand(unittest.TestCase): def test_remove_crew(self): mock_response = MagicMock() mock_response.status_code = 204 - self.mock_client.delete_by_name.return_value = mock_response + self.mock_client.delete_crew_by_name.return_value = mock_response with patch("sys.stdout", new=StringIO()) as fake_out: self.deploy_command.remove_crew(None) diff --git a/tests/cli/deploy/test_api.py b/tests/cli/test_plus_api.py similarity index 52% rename from tests/cli/deploy/test_api.py rename to tests/cli/test_plus_api.py index 77a5ff84d..506246290 100644 --- a/tests/cli/deploy/test_api.py +++ b/tests/cli/test_plus_api.py @@ -1,14 +1,13 @@ +import os import unittest -from os import environ from unittest.mock import MagicMock, patch - -from crewai.cli.deploy.api import CrewAPI +from crewai.cli.plus_api import PlusAPI -class TestCrewAPI(unittest.TestCase): +class TestPlusAPI(unittest.TestCase): def setUp(self): self.api_key = "test_api_key" - self.api = CrewAPI(self.api_key) + self.api = PlusAPI(self.api_key) def test_init(self): self.assertEqual(self.api.api_key, self.api_key) @@ -22,6 +21,70 @@ class TestCrewAPI(unittest.TestCase): }, ) + @patch("crewai.cli.plus_api.PlusAPI._make_request") + def test_get_tool(self, mock_make_request): + mock_response = MagicMock() + mock_make_request.return_value = mock_response + + response = self.api.get_tool("test_tool_handle") + + mock_make_request.assert_called_once_with( + "GET", "/crewai_plus/api/v1/tools/test_tool_handle" + ) + self.assertEqual(response, mock_response) + + @patch("crewai.cli.plus_api.PlusAPI._make_request") + def test_publish_tool(self, mock_make_request): + mock_response = MagicMock() + mock_make_request.return_value = mock_response + handle = "test_tool_handle" + public = True + version = "1.0.0" + description = "Test tool description" + encoded_file = "encoded_test_file" + + response = self.api.publish_tool( + handle, public, version, description, encoded_file + ) + + params = { + "handle": handle, + "public": public, + "version": version, + "file": encoded_file, + "description": description, + } + mock_make_request.assert_called_once_with( + "POST", "/crewai_plus/api/v1/tools", json=params + ) + self.assertEqual(response, mock_response) + + @patch("crewai.cli.plus_api.PlusAPI._make_request") + def test_publish_tool_without_description(self, mock_make_request): + mock_response = MagicMock() + mock_make_request.return_value = mock_response + handle = "test_tool_handle" + public = False + version = "2.0.0" + description = None + encoded_file = "encoded_test_file" + + response = self.api.publish_tool( + handle, public, version, description, encoded_file + ) + + params = { + "handle": handle, + "public": public, + "version": version, + "file": encoded_file, + "description": description, + } + mock_make_request.assert_called_once_with( + "POST", "/crewai_plus/api/v1/tools", json=params + ) + self.assertEqual(response, mock_response) + @patch("crewai.cli.plus_api.requests.request") def test_make_request(self, mock_request): mock_response = MagicMock() @@ -49,53 +112,53 @@ class TestCrewAPI(unittest.TestCase): ) @patch("crewai.cli.plus_api.PlusAPI._make_request") - def test_status_by_name(self, mock_make_request): - self.api.status_by_name("test_project") + def test_crew_status_by_name(self, mock_make_request): + self.api.crew_status_by_name("test_project") mock_make_request.assert_called_once_with( "GET", "/crewai_plus/api/v1/crews/by-name/test_project/status" ) @patch("crewai.cli.plus_api.PlusAPI._make_request") - def test_status_by_uuid(self, mock_make_request): - self.api.status_by_uuid("test_uuid") + def test_crew_status_by_uuid(self, mock_make_request): + self.api.crew_status_by_uuid("test_uuid") mock_make_request.assert_called_once_with( "GET", "/crewai_plus/api/v1/crews/test_uuid/status" ) @patch("crewai.cli.plus_api.PlusAPI._make_request") - def test_logs_by_name(self, mock_make_request): - self.api.logs_by_name("test_project") + def test_crew_by_name(self, mock_make_request): + self.api.crew_by_name("test_project") mock_make_request.assert_called_once_with( "GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/deployment" ) - self.api.logs_by_name("test_project", "custom_log") + self.api.crew_by_name("test_project", "custom_log") mock_make_request.assert_called_with( "GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log" ) @patch("crewai.cli.plus_api.PlusAPI._make_request") - def test_logs_by_uuid(self, mock_make_request): - self.api.logs_by_uuid("test_uuid") + def test_crew_by_uuid(self, mock_make_request): + self.api.crew_by_uuid("test_uuid") mock_make_request.assert_called_once_with( "GET", "/crewai_plus/api/v1/crews/test_uuid/logs/deployment" ) - self.api.logs_by_uuid("test_uuid", "custom_log") + self.api.crew_by_uuid("test_uuid", "custom_log") mock_make_request.assert_called_with( "GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log" ) @patch("crewai.cli.plus_api.PlusAPI._make_request") - def test_delete_by_name(self, mock_make_request): - self.api.delete_by_name("test_project") + def test_delete_crew_by_name(self, mock_make_request): + self.api.delete_crew_by_name("test_project") mock_make_request.assert_called_once_with( "DELETE", "/crewai_plus/api/v1/crews/by-name/test_project" ) @patch("crewai.cli.plus_api.PlusAPI._make_request") - def test_delete_by_uuid(self, mock_make_request): - self.api.delete_by_uuid("test_uuid") + def test_delete_crew_by_uuid(self, mock_make_request): + self.api.delete_crew_by_uuid("test_uuid") mock_make_request.assert_called_once_with( "DELETE", "/crewai_plus/api/v1/crews/test_uuid" ) @@ -105,7 +168,7 @@ class TestCrewAPI(unittest.TestCase): self.api.list_crews() mock_make_request.assert_called_once_with("GET", "/crewai_plus/api/v1/crews") - @patch("crewai.cli.deploy.api.CrewAPI._make_request") + @patch("crewai.cli.plus_api.PlusAPI._make_request") def test_create_crew(self, mock_make_request): payload = {"name": "test_crew"} self.api.create_crew(payload) @@ -113,9 +176,9 @@ class TestCrewAPI(unittest.TestCase): "POST", "/crewai_plus/api/v1/crews", json=payload ) - @patch.dict(environ, {"CREWAI_BASE_URL": "https://custom-url.com/api"}) + @patch.dict(os.environ, {"CREWAI_BASE_URL": "https://custom-url.com/api"}) def test_custom_base_url(self): - custom_api = CrewAPI("test_key") + custom_api = PlusAPI("test_key") self.assertEqual( custom_api.base_url, "https://custom-url.com/api", diff --git a/tests/cli/tools/test_api.py b/tests/cli/tools/test_api.py deleted file mode 100644 index 17c100678..000000000 --- a/tests/cli/tools/test_api.py +++ /dev/null @@ -1,69 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch -from crewai.cli.tools.api import ToolsAPI - - -class TestToolsAPI(unittest.TestCase): - def setUp(self): - self.api_key = "test_api_key" - self.api = ToolsAPI(self.api_key) - - @patch("crewai.cli.plus_api.PlusAPI._make_request") - def test_get_tool(self, mock_make_request): - mock_response = MagicMock() - mock_make_request.return_value = mock_response - - response = self.api.get("test_tool_handle") - - mock_make_request.assert_called_once_with( - "GET", "/crewai_plus/api/v1/tools/test_tool_handle" - ) - self.assertEqual(response, mock_response) - - @patch("crewai.cli.plus_api.PlusAPI._make_request") - def test_publish_tool(self, mock_make_request): - mock_response = MagicMock() - mock_make_request.return_value = mock_response - handle = "test_tool_handle" - public = True - version = "1.0.0" - description = "Test tool description" - encoded_file = "encoded_test_file" - - response = self.api.publish(handle, public, version, description, encoded_file) - - params = { - "handle": handle, - "public": public, - "version": version, - "file": encoded_file, - "description": description, - } - mock_make_request.assert_called_once_with( - "POST", "/crewai_plus/api/v1/tools", json=params - ) - self.assertEqual(response, mock_response) - - @patch("crewai.cli.plus_api.PlusAPI._make_request") - def test_publish_tool_without_description(self, mock_make_request): - mock_response = MagicMock() - mock_make_request.return_value = mock_response - handle = "test_tool_handle" - public = False - version = "2.0.0" - description = None - encoded_file = "encoded_test_file" - - response = self.api.publish(handle, public, version, description, encoded_file) - - params = { - "handle": handle, - "public": public, - "version": version, - "file": encoded_file, - "description": description, - } - mock_make_request.assert_called_once_with( - "POST", "/crewai_plus/api/v1/tools", json=params - ) - self.assertEqual(response, mock_response) diff --git a/tests/cli/tools/test_main.py b/tests/cli/tools/test_main.py new file mode 100644 index 000000000..f387c8d3f --- /dev/null +++ b/tests/cli/tools/test_main.py @@ -0,0 +1,229 @@ +import unittest +import unittest.mock +from crewai.cli.tools.main import ToolCommand +from io import StringIO +from unittest.mock import patch, MagicMock + + +class TestToolCommand(unittest.TestCase): + @patch("crewai.cli.tools.main.subprocess.run") + @patch("crewai.cli.plus_api.PlusAPI.get_tool") + def test_install_success(self, mock_get, mock_subprocess_run): + mock_get_response = MagicMock() + mock_get_response.status_code = 200 + mock_get_response.json.return_value = { + "handle": "sample-tool", + "repository": { + "handle": "sample-repo", + "url": "https://example.com/repo", + "credentials": "my_very_secret", + }, + } + mock_get.return_value = mock_get_response + mock_subprocess_run.return_value = MagicMock(stderr=None) + + tool_command = ToolCommand() + + with patch("sys.stdout", new=StringIO()) as fake_out: + tool_command.install("sample-tool") + output = fake_out.getvalue() + + mock_get.assert_called_once_with("sample-tool") + mock_subprocess_run.assert_any_call( + [ + "poetry", + "source", + "add", + "--priority=explicit", + "crewai-sample-repo", + "https://example.com/repo", + ], + text=True, + check=True, + ) + mock_subprocess_run.assert_any_call( + [ + "poetry", + "config", + "http-basic.crewai-sample-repo", + "my_very_secret", + '""', + ], + capture_output=False, + text=True, + check=True, + ) + mock_subprocess_run.assert_any_call( + ["poetry", "add", "--source", "crewai-sample-repo", "sample-tool"], + capture_output=False, + text=True, + check=True, + ) + + self.assertIn("Succesfully installed sample-tool", output) + + @patch("crewai.cli.plus_api.PlusAPI.get_tool") + def test_install_tool_not_found(self, mock_get): + mock_get_response = MagicMock() + mock_get_response.status_code = 404 + mock_get.return_value = mock_get_response + + tool_command = ToolCommand() + + with patch("sys.stdout", new=StringIO()) as fake_out: + with self.assertRaises(SystemExit): + tool_command.install("non-existent-tool") + output = fake_out.getvalue() + + mock_get.assert_called_once_with("non-existent-tool") + self.assertIn("No tool found with this name", output) + + @patch("crewai.cli.plus_api.PlusAPI.get_tool") + def test_install_api_error(self, mock_get): + mock_get_response = MagicMock() + mock_get_response.status_code = 500 + mock_get.return_value = mock_get_response + + tool_command = ToolCommand() + + with patch("sys.stdout", new=StringIO()) as fake_out: + with self.assertRaises(SystemExit): + tool_command.install("error-tool") + output = fake_out.getvalue() + + mock_get.assert_called_once_with("error-tool") + self.assertIn("Failed to get tool details", output) + + @patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool") + @patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0") + @patch( + "crewai.cli.tools.main.get_project_description", return_value="A sample tool" + ) + @patch("crewai.cli.tools.main.subprocess.run") + @patch( + "crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"] + ) + @patch( + "crewai.cli.tools.main.open", + new_callable=unittest.mock.mock_open, + read_data=b"sample tarball content", + ) + @patch("crewai.cli.plus_api.PlusAPI.publish_tool") + def test_publish_success( + self, + mock_publish, + mock_open, + mock_listdir, + mock_subprocess_run, + mock_get_project_description, + mock_get_project_version, + mock_get_project_name, + ): + mock_publish_response = MagicMock() + mock_publish_response.status_code = 200 + mock_publish_response.json.return_value = {"handle": "sample-tool"} + mock_publish.return_value = mock_publish_response + + tool_command = ToolCommand() + tool_command.publish(is_public=True) + + mock_get_project_name.assert_called_once_with(require=True) + mock_get_project_version.assert_called_once_with(require=True) + mock_get_project_description.assert_called_once_with(require=False) + mock_subprocess_run.assert_called_once_with( + ["poetry", "build", "-f", "sdist", "--output", unittest.mock.ANY], + check=True, + capture_output=False, + ) + mock_open.assert_called_once_with(unittest.mock.ANY, "rb") + mock_publish.assert_called_once_with( + handle="sample-tool", + is_public=True, + version="1.0.0", + description="A sample tool", + encoded_file=unittest.mock.ANY, + ) + + @patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool") + @patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0") + @patch( + "crewai.cli.tools.main.get_project_description", return_value="A sample tool" + ) + @patch("crewai.cli.tools.main.subprocess.run") + @patch( + "crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"] + ) + @patch( + "crewai.cli.tools.main.open", + new_callable=unittest.mock.mock_open, + read_data=b"sample tarball content", + ) + @patch("crewai.cli.plus_api.PlusAPI.publish_tool") + def test_publish_failure( + self, + mock_publish, + mock_open, + mock_listdir, + mock_subprocess_run, + mock_get_project_description, + mock_get_project_version, + mock_get_project_name, + ): + mock_publish_response = MagicMock() + mock_publish_response.status_code = 422 + mock_publish_response.json.return_value = {"name": ["is already taken"]} + mock_publish.return_value = mock_publish_response + + tool_command = ToolCommand() + + with patch("sys.stdout", new=StringIO()) as fake_out: + with self.assertRaises(SystemExit): + tool_command.publish(is_public=True) + output = fake_out.getvalue() + + mock_publish.assert_called_once() + self.assertIn("Failed to publish tool", output) + self.assertIn("Name is already taken", output) + + @patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool") + @patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0") + @patch( + "crewai.cli.tools.main.get_project_description", return_value="A sample tool" + ) + @patch("crewai.cli.tools.main.subprocess.run") + @patch( + "crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"] + ) + @patch( + "crewai.cli.tools.main.open", + new_callable=unittest.mock.mock_open, + read_data=b"sample tarball content", + ) + @patch("crewai.cli.plus_api.PlusAPI.publish_tool") + def test_publish_api_error( + self, + mock_publish, + mock_open, + mock_listdir, + mock_subprocess_run, + mock_get_project_description, + mock_get_project_version, + mock_get_project_name, + ): + mock_publish_response = MagicMock() + mock_publish_response.status_code = 500 + mock_publish.return_value = mock_publish_response + + tool_command = ToolCommand() + + with patch("sys.stdout", new=StringIO()) as fake_out: + with self.assertRaises(SystemExit): + tool_command.publish(is_public=True) + output = fake_out.getvalue() + + mock_publish.assert_called_once() + self.assertIn("Failed to publish tool", output) + + +if __name__ == "__main__": + unittest.main()