CLI for Tool Repository (#1357)

This commit adds two commands to the CLI:

- `crewai tool publish`
    - Builds the project using Poetry
    - Uploads the tarball to CrewAI's tool repository

- `crewai tool install my-tool`
    - Adds my-tool's index to Poetry and its credentials
    - Installs my-tool from the custom index
This commit is contained in:
Vini Brasil
2024-09-26 17:23:31 -03:00
committed by GitHub
parent 104ef7a0c2
commit c3ac3219fe
12 changed files with 696 additions and 246 deletions

View File

@@ -11,6 +11,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
from .authentication.main import AuthenticationCommand from .authentication.main import AuthenticationCommand
from .deploy.main import DeployCommand from .deploy.main import DeployCommand
from .tools.main import ToolCommand
from .evaluate_crew import evaluate_crew from .evaluate_crew import evaluate_crew
from .install_crew import install_crew from .install_crew import install_crew
from .replay_from_task import replay_task_command from .replay_from_task import replay_task_command
@@ -202,6 +203,12 @@ def deploy():
pass pass
@crewai.group()
def tool():
"""Tool Repository related commands."""
pass
@deploy.command(name="create") @deploy.command(name="create")
@click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt") @click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
def deploy_create(yes: bool): def deploy_create(yes: bool):
@@ -249,5 +256,20 @@ def deploy_remove(uuid: Optional[str]):
deploy_cmd.remove_crew(uuid=uuid) deploy_cmd.remove_crew(uuid=uuid)
@tool.command(name="install")
@click.argument("handle")
def tool_install(handle: str):
tool_cmd = ToolCommand()
tool_cmd.install(handle)
@tool.command(name="publish")
@click.option("--public", "is_public", flag_value=True, default=False)
@click.option("--private", "is_public", flag_value=False)
def tool_publish(is_public: bool):
tool_cmd = ToolCommand()
tool_cmd.publish(is_public)
if __name__ == "__main__": if __name__ == "__main__":
crewai() crewai()

40
src/crewai/cli/command.py Normal file
View File

@@ -0,0 +1,40 @@
from typing import Dict, Any
from rich.console import Console
from crewai.cli.plus_api import PlusAPI
from crewai.cli.utils import get_auth_token
from crewai.telemetry.telemetry import Telemetry
console = Console()
class BaseCommand:
def __init__(self):
self._telemetry = Telemetry()
self._telemetry.set_tracer()
class PlusAPIMixin:
def __init__(self, telemetry):
try:
telemetry.set_tracer()
self.plus_api_client = PlusAPI(api_key=get_auth_token())
except Exception:
self._deploy_signup_error_span = telemetry.deploy_signup_error_span()
console.print(
"Please sign up/login to CrewAI+ before using the CLI.",
style="bold red",
)
console.print("Run 'crewai signup' to sign up/login.", style="bold green")
raise SystemExit
def _handle_plus_api_error(self, json_response: Dict[str, Any]) -> None:
"""
Handle and display error messages from API responses.
Args:
json_response (Dict[str, Any]): The JSON response containing error information.
"""
error = json_response.get("error", "Unknown error")
message = json_response.get("message", "No message provided")
console.print(f"Error: {error}", style="bold red")
console.print(f"Message: {message}", style="bold red")

View File

@@ -1,56 +0,0 @@
import requests
from crewai.cli.plus_api import PlusAPI
class CrewAPI(PlusAPI):
"""
CrewAPI class to interact with the Crew resource in CrewAI+ API.
"""
RESOURCE = "/crewai_plus/api/v1/crews"
# Deploy
def deploy_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"POST", f"{self.RESOURCE}/by-name/{project_name}/deploy"
)
def deploy_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("POST", f"{self.RESOURCE}/{uuid}/deploy")
# Status
def status_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"GET", f"{self.RESOURCE}/by-name/{project_name}/status"
)
def status_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("GET", f"{self.RESOURCE}/{uuid}/status")
# Logs
def logs_by_name(
self, project_name: str, log_type: str = "deployment"
) -> requests.Response:
return self._make_request(
"GET", f"{self.RESOURCE}/by-name/{project_name}/logs/{log_type}"
)
def logs_by_uuid(
self, uuid: str, log_type: str = "deployment"
) -> requests.Response:
return self._make_request("GET", f"{self.RESOURCE}/{uuid}/logs/{log_type}")
# Delete
def delete_by_name(self, project_name: str) -> requests.Response:
return self._make_request("DELETE", f"{self.RESOURCE}/by-name/{project_name}")
def delete_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("DELETE", f"{self.RESOURCE}/{uuid}")
# List
def list_crews(self) -> requests.Response:
return self._make_request("GET", self.RESOURCE)
# Create
def create_crew(self, payload) -> requests.Response:
return self._make_request("POST", self.RESOURCE, json=payload)

View File

@@ -2,19 +2,17 @@ from typing import Any, Dict, List, Optional
from rich.console import Console from rich.console import Console
from crewai.telemetry import Telemetry from crewai.cli.command import BaseCommand, PlusAPIMixin
from crewai.cli.utils import ( from crewai.cli.utils import (
fetch_and_json_env_file, fetch_and_json_env_file,
get_auth_token,
get_git_remote_url, get_git_remote_url,
get_project_name, get_project_name,
) )
from .api import CrewAPI
console = Console() console = Console()
class DeployCommand: class DeployCommand(BaseCommand, PlusAPIMixin):
""" """
A class to handle deployment-related operations for CrewAI projects. A class to handle deployment-related operations for CrewAI projects.
""" """
@@ -23,40 +21,10 @@ class DeployCommand:
""" """
Initialize the DeployCommand with project name and API client. Initialize the DeployCommand with project name and API client.
""" """
try:
self._telemetry = Telemetry()
self._telemetry.set_tracer()
access_token = get_auth_token()
except Exception:
self._deploy_signup_error_span = self._telemetry.deploy_signup_error_span()
console.print(
"Please sign up/login to CrewAI+ before using the CLI.",
style="bold red",
)
console.print("Run 'crewai signup' to sign up/login.", style="bold green")
raise SystemExit
self.project_name = get_project_name() BaseCommand.__init__(self)
if self.project_name is None: PlusAPIMixin.__init__(self, telemetry=self._telemetry)
console.print( self.project_name = get_project_name(require=True)
"No project name found. Please ensure your project has a valid pyproject.toml file.",
style="bold red",
)
raise SystemExit
self.client = CrewAPI(api_key=access_token)
def _handle_error(self, json_response: Dict[str, Any]) -> None:
"""
Handle and display error messages from API responses.
Args:
json_response (Dict[str, Any]): The JSON response containing error information.
"""
error = json_response.get("error", "Unknown error")
message = json_response.get("message", "No message provided")
console.print(f"Error: {error}", style="bold red")
console.print(f"Message: {message}", style="bold red")
def _standard_no_param_error_message(self) -> None: def _standard_no_param_error_message(self) -> None:
""" """
@@ -104,9 +72,9 @@ class DeployCommand:
self._start_deployment_span = self._telemetry.start_deployment_span(uuid) self._start_deployment_span = self._telemetry.start_deployment_span(uuid)
console.print("Starting deployment...", style="bold blue") console.print("Starting deployment...", style="bold blue")
if uuid: if uuid:
response = self.client.deploy_by_uuid(uuid) response = self.plus_api_client.deploy_by_uuid(uuid)
elif self.project_name: elif self.project_name:
response = self.client.deploy_by_name(self.project_name) response = self.plus_api_client.deploy_by_name(self.project_name)
else: else:
self._standard_no_param_error_message() self._standard_no_param_error_message()
return return
@@ -115,7 +83,7 @@ class DeployCommand:
if response.status_code == 200: if response.status_code == 200:
self._display_deployment_info(json_response) self._display_deployment_info(json_response)
else: else:
self._handle_error(json_response) self._handle_plus_api_error(json_response)
def create_crew(self, confirm: bool = False) -> None: def create_crew(self, confirm: bool = False) -> None:
""" """
@@ -139,11 +107,11 @@ class DeployCommand:
self._confirm_input(env_vars, remote_repo_url, confirm) self._confirm_input(env_vars, remote_repo_url, confirm)
payload = self._create_payload(env_vars, remote_repo_url) payload = self._create_payload(env_vars, remote_repo_url)
response = self.client.create_crew(payload) response = self.plus_api_client.create_crew(payload)
if response.status_code == 201: if response.status_code == 201:
self._display_creation_success(response.json()) self._display_creation_success(response.json())
else: else:
self._handle_error(response.json()) self._handle_plus_api_error(response.json())
def _confirm_input( def _confirm_input(
self, env_vars: Dict[str, str], remote_repo_url: str, confirm: bool self, env_vars: Dict[str, str], remote_repo_url: str, confirm: bool
@@ -208,7 +176,7 @@ class DeployCommand:
""" """
console.print("Listing all Crews\n", style="bold blue") console.print("Listing all Crews\n", style="bold blue")
response = self.client.list_crews() response = self.plus_api_client.list_crews()
json_response = response.json() json_response = response.json()
if response.status_code == 200: if response.status_code == 200:
self._display_crews(json_response) self._display_crews(json_response)
@@ -243,9 +211,9 @@ class DeployCommand:
""" """
console.print("Fetching deployment status...", style="bold blue") console.print("Fetching deployment status...", style="bold blue")
if uuid: if uuid:
response = self.client.status_by_uuid(uuid) response = self.plus_api_client.crew_status_by_uuid(uuid)
elif self.project_name: elif self.project_name:
response = self.client.status_by_name(self.project_name) response = self.plus_api_client.crew_status_by_name(self.project_name)
else: else:
self._standard_no_param_error_message() self._standard_no_param_error_message()
return return
@@ -254,7 +222,7 @@ class DeployCommand:
if response.status_code == 200: if response.status_code == 200:
self._display_crew_status(json_response) self._display_crew_status(json_response)
else: else:
self._handle_error(json_response) self._handle_plus_api_error(json_response)
def _display_crew_status(self, status_data: Dict[str, str]) -> None: def _display_crew_status(self, status_data: Dict[str, str]) -> None:
""" """
@@ -278,9 +246,9 @@ class DeployCommand:
console.print(f"Fetching {log_type} logs...", style="bold blue") console.print(f"Fetching {log_type} logs...", style="bold blue")
if uuid: if uuid:
response = self.client.logs_by_uuid(uuid, log_type) response = self.plus_api_client.crew_by_uuid(uuid, log_type)
elif self.project_name: elif self.project_name:
response = self.client.logs_by_name(self.project_name, log_type) response = self.plus_api_client.crew_by_name(self.project_name, log_type)
else: else:
self._standard_no_param_error_message() self._standard_no_param_error_message()
return return
@@ -288,7 +256,7 @@ class DeployCommand:
if response.status_code == 200: if response.status_code == 200:
self._display_logs(response.json()) self._display_logs(response.json())
else: else:
self._handle_error(response.json()) self._handle_plus_api_error(response.json())
def remove_crew(self, uuid: Optional[str]) -> None: def remove_crew(self, uuid: Optional[str]) -> None:
""" """
@@ -301,9 +269,9 @@ class DeployCommand:
console.print("Removing deployment...", style="bold blue") console.print("Removing deployment...", style="bold blue")
if uuid: if uuid:
response = self.client.delete_by_uuid(uuid) response = self.plus_api_client.delete_crew_by_uuid(uuid)
elif self.project_name: elif self.project_name:
response = self.client.delete_by_name(self.project_name) response = self.plus_api_client.delete_crew_by_name(self.project_name)
else: else:
self._standard_no_param_error_message() self._standard_no_param_error_message()
return return

View File

@@ -1,3 +1,4 @@
from typing import Optional
import requests import requests
from os import getenv from os import getenv
from crewai.cli.utils import get_crewai_version from crewai.cli.utils import get_crewai_version
@@ -9,6 +10,9 @@ class PlusAPI:
This class exposes methods for working with the CrewAI+ API. This class exposes methods for working with the CrewAI+ API.
""" """
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
def __init__(self, api_key: str) -> None: def __init__(self, api_key: str) -> None:
self.api_key = api_key self.api_key = api_key
self.headers = { self.headers = {
@@ -22,3 +26,67 @@ class PlusAPI:
def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response: def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
url = urljoin(self.base_url, endpoint) url = urljoin(self.base_url, endpoint)
return requests.request(method, url, headers=self.headers, **kwargs) return requests.request(method, url, headers=self.headers, **kwargs)
def get_tool(self, handle: str):
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
def publish_tool(
self,
handle: str,
is_public: bool,
version: str,
description: Optional[str],
encoded_file: str,
):
params = {
"handle": handle,
"public": is_public,
"version": version,
"file": encoded_file,
"description": description,
}
return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params)
def deploy_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy"
)
def deploy_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("POST", f"{self.CREWS_RESOURCE}/{uuid}/deploy")
def crew_status_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status"
)
def crew_status_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")
def crew_by_name(
self, project_name: str, log_type: str = "deployment"
) -> requests.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}"
)
def crew_by_uuid(
self, uuid: str, log_type: str = "deployment"
) -> requests.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}"
)
def delete_crew_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}"
)
def delete_crew_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("DELETE", f"{self.CREWS_RESOURCE}/{uuid}")
def list_crews(self) -> requests.Response:
return self._make_request("GET", self.CREWS_RESOURCE)
def create_crew(self, payload) -> requests.Response:
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)

View File

@@ -1,26 +0,0 @@
from typing import Optional
from crewai.cli.plus_api import PlusAPI
class ToolsAPI(PlusAPI):
RESOURCE = "/crewai_plus/api/v1/tools"
def get(self, handle: str):
return self._make_request("GET", f"{self.RESOURCE}/{handle}")
def publish(
self,
handle: str,
public: bool,
version: str,
description: Optional[str],
encoded_file: str,
):
params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
}
return self._make_request("POST", f"{self.RESOURCE}", json=params)

View File

@@ -0,0 +1,168 @@
import base64
import click
import os
import subprocess
import tempfile
from crewai.cli.command import BaseCommand, PlusAPIMixin
from crewai.cli.utils import (
get_project_name,
get_project_description,
get_project_version,
)
from rich.console import Console
console = Console()
class ToolCommand(BaseCommand, PlusAPIMixin):
"""
A class to handle tool repository related operations for CrewAI projects.
"""
def __init__(self):
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
def publish(self, is_public: bool):
project_name = get_project_name(require=True)
assert isinstance(project_name, str)
project_version = get_project_version(require=True)
assert isinstance(project_version, str)
project_description = get_project_description(require=False)
encoded_tarball = None
with tempfile.TemporaryDirectory() as temp_build_dir:
subprocess.run(
["poetry", "build", "-f", "sdist", "--output", temp_build_dir],
check=True,
capture_output=False,
)
tarball_filename = next(
(f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None
)
if not tarball_filename:
console.print(
"Project build failed. Please ensure that the command `poetry build -f sdist` completes successfully.",
style="bold red",
)
raise SystemExit
tarball_path = os.path.join(temp_build_dir, tarball_filename)
with open(tarball_path, "rb") as file:
tarball_contents = file.read()
encoded_tarball = base64.b64encode(tarball_contents).decode("utf-8")
publish_response = self.plus_api_client.publish_tool(
handle=project_name,
is_public=is_public,
version=project_version,
description=project_description,
encoded_file=f"data:application/x-gzip;base64,{encoded_tarball}",
)
if publish_response.status_code == 422:
console.print(
"[bold red]Failed to publish tool. Please fix the following errors:[/bold red]"
)
for field, messages in publish_response.json().items():
for message in messages:
console.print(
f"* [bold red]{field.capitalize()}[/bold red] {message}"
)
raise SystemExit
elif publish_response.status_code != 200:
self._handle_plus_api_error(publish_response.json())
console.print(
"Failed to publish tool. Please try again later.", style="bold red"
)
raise SystemExit
published_handle = publish_response.json()["handle"]
console.print(
f"Succesfully published {published_handle} ({project_version}).\nInstall it in other projects with crewai tool install {published_handle}",
style="bold green",
)
def install(self, handle: str):
get_response = self.plus_api_client.get_tool(handle)
if get_response.status_code == 404:
console.print(
"No tool found with this name. Please ensure the tool was published and you have access to it.",
style="bold red",
)
raise SystemExit
elif get_response.status_code != 200:
console.print(
"Failed to get tool details. Please try again later.", style="bold red"
)
raise SystemExit
self._add_repository_to_poetry(get_response.json())
self._add_package(get_response.json())
console.print(f"Succesfully installed {handle}", style="bold green")
def _add_repository_to_poetry(self, tool_details):
repository_handle = f"crewai-{tool_details['repository']['handle']}"
repository_url = tool_details["repository"]["url"]
repository_credentials = tool_details["repository"]["credentials"]
add_repository_command = [
"poetry",
"source",
"add",
"--priority=explicit",
repository_handle,
repository_url,
]
add_repository_result = subprocess.run(
add_repository_command, text=True, check=True
)
if add_repository_result.stderr:
click.echo(add_repository_result.stderr, err=True)
raise SystemExit
add_repository_credentials_command = [
"poetry",
"config",
f"http-basic.{repository_handle}",
repository_credentials,
'""',
]
add_repository_credentials_result = subprocess.run(
add_repository_credentials_command,
capture_output=False,
text=True,
check=True,
)
if add_repository_credentials_result.stderr:
click.echo(add_repository_credentials_result.stderr, err=True)
raise SystemExit
def _add_package(self, tool_details):
tool_handle = tool_details["handle"]
repository_handle = tool_details["repository"]["handle"]
pypi_index_handle = f"crewai-{repository_handle}"
add_package_command = [
"poetry",
"add",
"--source",
pypi_index_handle,
tool_handle,
]
add_package_result = subprocess.run(
add_package_command, capture_output=False, text=True, check=True
)
if add_package_result.stderr:
click.echo(add_package_result.stderr, err=True)
raise SystemExit

View File

@@ -3,8 +3,10 @@ import re
import subprocess import subprocess
import sys import sys
from rich.console import Console
from crewai.cli.authentication.utils import TokenManager from crewai.cli.authentication.utils import TokenManager
from functools import reduce
from rich.console import Console
from typing import Any, Dict, List
if sys.version_info >= (3, 11): if sys.version_info >= (3, 11):
import tomllib import tomllib
@@ -88,21 +90,51 @@ def get_git_remote_url() -> str | None:
return None return None
def get_project_name(pyproject_path: str = "pyproject.toml") -> str | None: def get_project_name(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project name from the pyproject.toml file.""" """Get the project name from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["tool", "poetry", "name"], require=require
)
def get_project_version(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project version from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["tool", "poetry", "version"], require=require
)
def get_project_description(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project description from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["tool", "poetry", "description"], require=require
)
def _get_project_attribute(
pyproject_path: str, keys: List[str], require: bool
) -> Any | None:
"""Get an attribute from the pyproject.toml file."""
attribute = None
try: try:
# Read the pyproject.toml file
with open(pyproject_path, "r") as f: with open(pyproject_path, "r") as f:
pyproject_content = parse_toml(f.read()) pyproject_content = parse_toml(f.read())
# Extract the project name dependencies = (
project_name = pyproject_content["tool"]["poetry"]["name"] _get_nested_value(pyproject_content, ["tool", "poetry", "dependencies"])
or {}
if "crewai" not in pyproject_content["tool"]["poetry"]["dependencies"]: )
if "crewai" not in dependencies:
raise Exception("crewai is not in the dependencies.") raise Exception("crewai is not in the dependencies.")
return project_name attribute = _get_nested_value(pyproject_content, keys)
except FileNotFoundError: except FileNotFoundError:
print(f"Error: {pyproject_path} not found.") print(f"Error: {pyproject_path} not found.")
except KeyError: except KeyError:
@@ -116,7 +148,18 @@ def get_project_name(pyproject_path: str = "pyproject.toml") -> str | None:
except Exception as e: except Exception as e:
print(f"Error reading the pyproject.toml file: {e}") print(f"Error reading the pyproject.toml file: {e}")
return None if require and not attribute:
console.print(
f"Unable to read '{'.'.join(keys)}' in the pyproject.toml file. Please verify that the file exists and contains the specified attribute.",
style="bold red",
)
raise SystemExit
return attribute
def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
return reduce(dict.__getitem__, keys, data)
def get_crewai_version(poetry_lock_path: str = "poetry.lock") -> str: def get_crewai_version(poetry_lock_path: str = "poetry.lock") -> str:

View File

@@ -8,34 +8,34 @@ from crewai.cli.utils import parse_toml
class TestDeployCommand(unittest.TestCase): class TestDeployCommand(unittest.TestCase):
@patch("crewai.cli.deploy.main.get_auth_token") @patch("crewai.cli.command.get_auth_token")
@patch("crewai.cli.deploy.main.get_project_name") @patch("crewai.cli.deploy.main.get_project_name")
@patch("crewai.cli.deploy.main.CrewAPI") @patch("crewai.cli.command.PlusAPI")
def setUp(self, mock_crew_api, mock_get_project_name, mock_get_auth_token): def setUp(self, mock_plus_api, mock_get_project_name, mock_get_auth_token):
self.mock_get_auth_token = mock_get_auth_token self.mock_get_auth_token = mock_get_auth_token
self.mock_get_project_name = mock_get_project_name self.mock_get_project_name = mock_get_project_name
self.mock_crew_api = mock_crew_api self.mock_plus_api = mock_plus_api
self.mock_get_auth_token.return_value = "test_token" self.mock_get_auth_token.return_value = "test_token"
self.mock_get_project_name.return_value = "test_project" self.mock_get_project_name.return_value = "test_project"
self.deploy_command = DeployCommand() self.deploy_command = DeployCommand()
self.mock_client = self.deploy_command.client self.mock_client = self.deploy_command.plus_api_client
def test_init_success(self): def test_init_success(self):
self.assertEqual(self.deploy_command.project_name, "test_project") self.assertEqual(self.deploy_command.project_name, "test_project")
self.mock_crew_api.assert_called_once_with(api_key="test_token") self.mock_plus_api.assert_called_once_with(api_key="test_token")
@patch("crewai.cli.deploy.main.get_auth_token") @patch("crewai.cli.command.get_auth_token")
def test_init_failure(self, mock_get_auth_token): def test_init_failure(self, mock_get_auth_token):
mock_get_auth_token.side_effect = Exception("Auth failed") mock_get_auth_token.side_effect = Exception("Auth failed")
with self.assertRaises(SystemExit): with self.assertRaises(SystemExit):
DeployCommand() DeployCommand()
def test_handle_error(self): def test_handle_plus_api_error(self):
with patch("sys.stdout", new=StringIO()) as fake_out: with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command._handle_error( self.deploy_command._handle_plus_api_error(
{"error": "Test error", "message": "Test message"} {"error": "Test error", "message": "Test message"}
) )
self.assertIn("Error: Test error", fake_out.getvalue()) self.assertIn("Error: Test error", fake_out.getvalue())
@@ -122,7 +122,7 @@ class TestDeployCommand(unittest.TestCase):
mock_response = MagicMock() mock_response = MagicMock()
mock_response.status_code = 200 mock_response.status_code = 200
mock_response.json.return_value = {"name": "TestCrew", "status": "active"} mock_response.json.return_value = {"name": "TestCrew", "status": "active"}
self.mock_client.status_by_name.return_value = mock_response self.mock_client.crew_status_by_name.return_value = mock_response
with patch("sys.stdout", new=StringIO()) as fake_out: with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command.get_crew_status() self.deploy_command.get_crew_status()
@@ -136,7 +136,7 @@ class TestDeployCommand(unittest.TestCase):
{"timestamp": "2023-01-01", "level": "INFO", "message": "Log1"}, {"timestamp": "2023-01-01", "level": "INFO", "message": "Log1"},
{"timestamp": "2023-01-02", "level": "ERROR", "message": "Log2"}, {"timestamp": "2023-01-02", "level": "ERROR", "message": "Log2"},
] ]
self.mock_client.logs_by_name.return_value = mock_response self.mock_client.crew_by_name.return_value = mock_response
with patch("sys.stdout", new=StringIO()) as fake_out: with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command.get_crew_logs(None) self.deploy_command.get_crew_logs(None)
@@ -146,7 +146,7 @@ class TestDeployCommand(unittest.TestCase):
def test_remove_crew(self): def test_remove_crew(self):
mock_response = MagicMock() mock_response = MagicMock()
mock_response.status_code = 204 mock_response.status_code = 204
self.mock_client.delete_by_name.return_value = mock_response self.mock_client.delete_crew_by_name.return_value = mock_response
with patch("sys.stdout", new=StringIO()) as fake_out: with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command.remove_crew(None) self.deploy_command.remove_crew(None)

View File

@@ -1,14 +1,13 @@
import os
import unittest import unittest
from os import environ
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from crewai.cli.plus_api import PlusAPI
from crewai.cli.deploy.api import CrewAPI
class TestCrewAPI(unittest.TestCase): class TestPlusAPI(unittest.TestCase):
def setUp(self): def setUp(self):
self.api_key = "test_api_key" self.api_key = "test_api_key"
self.api = CrewAPI(self.api_key) self.api = PlusAPI(self.api_key)
def test_init(self): def test_init(self):
self.assertEqual(self.api.api_key, self.api_key) self.assertEqual(self.api.api_key, self.api_key)
@@ -22,6 +21,70 @@ class TestCrewAPI(unittest.TestCase):
}, },
) )
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_get_tool(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.get_tool("test_tool_handle")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_publish_tool(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
public = True
version = "1.0.0"
description = "Test tool description"
encoded_file = "encoded_test_file"
response = self.api.publish_tool(
handle, public, version, description, encoded_file
)
params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
}
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools", json=params
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_publish_tool_without_description(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
public = False
version = "2.0.0"
description = None
encoded_file = "encoded_test_file"
response = self.api.publish_tool(
handle, public, version, description, encoded_file
)
params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
}
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools", json=params
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.requests.request") @patch("crewai.cli.plus_api.requests.request")
def test_make_request(self, mock_request): def test_make_request(self, mock_request):
mock_response = MagicMock() mock_response = MagicMock()
@@ -49,53 +112,53 @@ class TestCrewAPI(unittest.TestCase):
) )
@patch("crewai.cli.plus_api.PlusAPI._make_request") @patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_status_by_name(self, mock_make_request): def test_crew_status_by_name(self, mock_make_request):
self.api.status_by_name("test_project") self.api.crew_status_by_name("test_project")
mock_make_request.assert_called_once_with( mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/status" "GET", "/crewai_plus/api/v1/crews/by-name/test_project/status"
) )
@patch("crewai.cli.plus_api.PlusAPI._make_request") @patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_status_by_uuid(self, mock_make_request): def test_crew_status_by_uuid(self, mock_make_request):
self.api.status_by_uuid("test_uuid") self.api.crew_status_by_uuid("test_uuid")
mock_make_request.assert_called_once_with( mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/status" "GET", "/crewai_plus/api/v1/crews/test_uuid/status"
) )
@patch("crewai.cli.plus_api.PlusAPI._make_request") @patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_logs_by_name(self, mock_make_request): def test_crew_by_name(self, mock_make_request):
self.api.logs_by_name("test_project") self.api.crew_by_name("test_project")
mock_make_request.assert_called_once_with( mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/deployment" "GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/deployment"
) )
self.api.logs_by_name("test_project", "custom_log") self.api.crew_by_name("test_project", "custom_log")
mock_make_request.assert_called_with( mock_make_request.assert_called_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log" "GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log"
) )
@patch("crewai.cli.plus_api.PlusAPI._make_request") @patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_logs_by_uuid(self, mock_make_request): def test_crew_by_uuid(self, mock_make_request):
self.api.logs_by_uuid("test_uuid") self.api.crew_by_uuid("test_uuid")
mock_make_request.assert_called_once_with( mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/deployment" "GET", "/crewai_plus/api/v1/crews/test_uuid/logs/deployment"
) )
self.api.logs_by_uuid("test_uuid", "custom_log") self.api.crew_by_uuid("test_uuid", "custom_log")
mock_make_request.assert_called_with( mock_make_request.assert_called_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log" "GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log"
) )
@patch("crewai.cli.plus_api.PlusAPI._make_request") @patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_delete_by_name(self, mock_make_request): def test_delete_crew_by_name(self, mock_make_request):
self.api.delete_by_name("test_project") self.api.delete_crew_by_name("test_project")
mock_make_request.assert_called_once_with( mock_make_request.assert_called_once_with(
"DELETE", "/crewai_plus/api/v1/crews/by-name/test_project" "DELETE", "/crewai_plus/api/v1/crews/by-name/test_project"
) )
@patch("crewai.cli.plus_api.PlusAPI._make_request") @patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_delete_by_uuid(self, mock_make_request): def test_delete_crew_by_uuid(self, mock_make_request):
self.api.delete_by_uuid("test_uuid") self.api.delete_crew_by_uuid("test_uuid")
mock_make_request.assert_called_once_with( mock_make_request.assert_called_once_with(
"DELETE", "/crewai_plus/api/v1/crews/test_uuid" "DELETE", "/crewai_plus/api/v1/crews/test_uuid"
) )
@@ -105,7 +168,7 @@ class TestCrewAPI(unittest.TestCase):
self.api.list_crews() self.api.list_crews()
mock_make_request.assert_called_once_with("GET", "/crewai_plus/api/v1/crews") mock_make_request.assert_called_once_with("GET", "/crewai_plus/api/v1/crews")
@patch("crewai.cli.deploy.api.CrewAPI._make_request") @patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_create_crew(self, mock_make_request): def test_create_crew(self, mock_make_request):
payload = {"name": "test_crew"} payload = {"name": "test_crew"}
self.api.create_crew(payload) self.api.create_crew(payload)
@@ -113,9 +176,9 @@ class TestCrewAPI(unittest.TestCase):
"POST", "/crewai_plus/api/v1/crews", json=payload "POST", "/crewai_plus/api/v1/crews", json=payload
) )
@patch.dict(environ, {"CREWAI_BASE_URL": "https://custom-url.com/api"}) @patch.dict(os.environ, {"CREWAI_BASE_URL": "https://custom-url.com/api"})
def test_custom_base_url(self): def test_custom_base_url(self):
custom_api = CrewAPI("test_key") custom_api = PlusAPI("test_key")
self.assertEqual( self.assertEqual(
custom_api.base_url, custom_api.base_url,
"https://custom-url.com/api", "https://custom-url.com/api",

View File

@@ -1,69 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
from crewai.cli.tools.api import ToolsAPI
class TestToolsAPI(unittest.TestCase):
def setUp(self):
self.api_key = "test_api_key"
self.api = ToolsAPI(self.api_key)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_get_tool(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.get("test_tool_handle")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_publish_tool(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
public = True
version = "1.0.0"
description = "Test tool description"
encoded_file = "encoded_test_file"
response = self.api.publish(handle, public, version, description, encoded_file)
params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
}
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools", json=params
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_publish_tool_without_description(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
public = False
version = "2.0.0"
description = None
encoded_file = "encoded_test_file"
response = self.api.publish(handle, public, version, description, encoded_file)
params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
}
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools", json=params
)
self.assertEqual(response, mock_response)

View File

@@ -0,0 +1,229 @@
import unittest
import unittest.mock
from crewai.cli.tools.main import ToolCommand
from io import StringIO
from unittest.mock import patch, MagicMock
class TestToolCommand(unittest.TestCase):
@patch("crewai.cli.tools.main.subprocess.run")
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
def test_install_success(self, mock_get, mock_subprocess_run):
mock_get_response = MagicMock()
mock_get_response.status_code = 200
mock_get_response.json.return_value = {
"handle": "sample-tool",
"repository": {
"handle": "sample-repo",
"url": "https://example.com/repo",
"credentials": "my_very_secret",
},
}
mock_get.return_value = mock_get_response
mock_subprocess_run.return_value = MagicMock(stderr=None)
tool_command = ToolCommand()
with patch("sys.stdout", new=StringIO()) as fake_out:
tool_command.install("sample-tool")
output = fake_out.getvalue()
mock_get.assert_called_once_with("sample-tool")
mock_subprocess_run.assert_any_call(
[
"poetry",
"source",
"add",
"--priority=explicit",
"crewai-sample-repo",
"https://example.com/repo",
],
text=True,
check=True,
)
mock_subprocess_run.assert_any_call(
[
"poetry",
"config",
"http-basic.crewai-sample-repo",
"my_very_secret",
'""',
],
capture_output=False,
text=True,
check=True,
)
mock_subprocess_run.assert_any_call(
["poetry", "add", "--source", "crewai-sample-repo", "sample-tool"],
capture_output=False,
text=True,
check=True,
)
self.assertIn("Succesfully installed sample-tool", output)
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
def test_install_tool_not_found(self, mock_get):
mock_get_response = MagicMock()
mock_get_response.status_code = 404
mock_get.return_value = mock_get_response
tool_command = ToolCommand()
with patch("sys.stdout", new=StringIO()) as fake_out:
with self.assertRaises(SystemExit):
tool_command.install("non-existent-tool")
output = fake_out.getvalue()
mock_get.assert_called_once_with("non-existent-tool")
self.assertIn("No tool found with this name", output)
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
def test_install_api_error(self, mock_get):
mock_get_response = MagicMock()
mock_get_response.status_code = 500
mock_get.return_value = mock_get_response
tool_command = ToolCommand()
with patch("sys.stdout", new=StringIO()) as fake_out:
with self.assertRaises(SystemExit):
tool_command.install("error-tool")
output = fake_out.getvalue()
mock_get.assert_called_once_with("error-tool")
self.assertIn("Failed to get tool details", output)
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
@patch(
"crewai.cli.tools.main.get_project_description", return_value="A sample tool"
)
@patch("crewai.cli.tools.main.subprocess.run")
@patch(
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
)
@patch(
"crewai.cli.tools.main.open",
new_callable=unittest.mock.mock_open,
read_data=b"sample tarball content",
)
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
def test_publish_success(
self,
mock_publish,
mock_open,
mock_listdir,
mock_subprocess_run,
mock_get_project_description,
mock_get_project_version,
mock_get_project_name,
):
mock_publish_response = MagicMock()
mock_publish_response.status_code = 200
mock_publish_response.json.return_value = {"handle": "sample-tool"}
mock_publish.return_value = mock_publish_response
tool_command = ToolCommand()
tool_command.publish(is_public=True)
mock_get_project_name.assert_called_once_with(require=True)
mock_get_project_version.assert_called_once_with(require=True)
mock_get_project_description.assert_called_once_with(require=False)
mock_subprocess_run.assert_called_once_with(
["poetry", "build", "-f", "sdist", "--output", unittest.mock.ANY],
check=True,
capture_output=False,
)
mock_open.assert_called_once_with(unittest.mock.ANY, "rb")
mock_publish.assert_called_once_with(
handle="sample-tool",
is_public=True,
version="1.0.0",
description="A sample tool",
encoded_file=unittest.mock.ANY,
)
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
@patch(
"crewai.cli.tools.main.get_project_description", return_value="A sample tool"
)
@patch("crewai.cli.tools.main.subprocess.run")
@patch(
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
)
@patch(
"crewai.cli.tools.main.open",
new_callable=unittest.mock.mock_open,
read_data=b"sample tarball content",
)
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
def test_publish_failure(
self,
mock_publish,
mock_open,
mock_listdir,
mock_subprocess_run,
mock_get_project_description,
mock_get_project_version,
mock_get_project_name,
):
mock_publish_response = MagicMock()
mock_publish_response.status_code = 422
mock_publish_response.json.return_value = {"name": ["is already taken"]}
mock_publish.return_value = mock_publish_response
tool_command = ToolCommand()
with patch("sys.stdout", new=StringIO()) as fake_out:
with self.assertRaises(SystemExit):
tool_command.publish(is_public=True)
output = fake_out.getvalue()
mock_publish.assert_called_once()
self.assertIn("Failed to publish tool", output)
self.assertIn("Name is already taken", output)
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
@patch(
"crewai.cli.tools.main.get_project_description", return_value="A sample tool"
)
@patch("crewai.cli.tools.main.subprocess.run")
@patch(
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
)
@patch(
"crewai.cli.tools.main.open",
new_callable=unittest.mock.mock_open,
read_data=b"sample tarball content",
)
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
def test_publish_api_error(
self,
mock_publish,
mock_open,
mock_listdir,
mock_subprocess_run,
mock_get_project_description,
mock_get_project_version,
mock_get_project_name,
):
mock_publish_response = MagicMock()
mock_publish_response.status_code = 500
mock_publish.return_value = mock_publish_response
tool_command = ToolCommand()
with patch("sys.stdout", new=StringIO()) as fake_out:
with self.assertRaises(SystemExit):
tool_command.publish(is_public=True)
output = fake_out.getvalue()
mock_publish.assert_called_once()
self.assertIn("Failed to publish tool", output)
if __name__ == "__main__":
unittest.main()