CLI for Tool Repository (#1357)

This commit adds two commands to the CLI:

- `crewai tool publish`
    - Builds the project using Poetry
    - Uploads the tarball to CrewAI's tool repository

- `crewai tool install my-tool`
    - Adds my-tool's index to Poetry and its credentials
    - Installs my-tool from the custom index
This commit is contained in:
Vini Brasil
2024-09-26 17:23:31 -03:00
committed by GitHub
parent 104ef7a0c2
commit c3ac3219fe
12 changed files with 696 additions and 246 deletions

View File

@@ -11,6 +11,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
from .authentication.main import AuthenticationCommand
from .deploy.main import DeployCommand
from .tools.main import ToolCommand
from .evaluate_crew import evaluate_crew
from .install_crew import install_crew
from .replay_from_task import replay_task_command
@@ -202,6 +203,12 @@ def deploy():
pass
@crewai.group()
def tool():
"""Tool Repository related commands."""
pass
@deploy.command(name="create")
@click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
def deploy_create(yes: bool):
@@ -249,5 +256,20 @@ def deploy_remove(uuid: Optional[str]):
deploy_cmd.remove_crew(uuid=uuid)
@tool.command(name="install")
@click.argument("handle")
def tool_install(handle: str):
tool_cmd = ToolCommand()
tool_cmd.install(handle)
@tool.command(name="publish")
@click.option("--public", "is_public", flag_value=True, default=False)
@click.option("--private", "is_public", flag_value=False)
def tool_publish(is_public: bool):
tool_cmd = ToolCommand()
tool_cmd.publish(is_public)
if __name__ == "__main__":
crewai()

40
src/crewai/cli/command.py Normal file
View File

@@ -0,0 +1,40 @@
from typing import Dict, Any
from rich.console import Console
from crewai.cli.plus_api import PlusAPI
from crewai.cli.utils import get_auth_token
from crewai.telemetry.telemetry import Telemetry
console = Console()
class BaseCommand:
def __init__(self):
self._telemetry = Telemetry()
self._telemetry.set_tracer()
class PlusAPIMixin:
def __init__(self, telemetry):
try:
telemetry.set_tracer()
self.plus_api_client = PlusAPI(api_key=get_auth_token())
except Exception:
self._deploy_signup_error_span = telemetry.deploy_signup_error_span()
console.print(
"Please sign up/login to CrewAI+ before using the CLI.",
style="bold red",
)
console.print("Run 'crewai signup' to sign up/login.", style="bold green")
raise SystemExit
def _handle_plus_api_error(self, json_response: Dict[str, Any]) -> None:
"""
Handle and display error messages from API responses.
Args:
json_response (Dict[str, Any]): The JSON response containing error information.
"""
error = json_response.get("error", "Unknown error")
message = json_response.get("message", "No message provided")
console.print(f"Error: {error}", style="bold red")
console.print(f"Message: {message}", style="bold red")

View File

@@ -1,56 +0,0 @@
import requests
from crewai.cli.plus_api import PlusAPI
class CrewAPI(PlusAPI):
"""
CrewAPI class to interact with the Crew resource in CrewAI+ API.
"""
RESOURCE = "/crewai_plus/api/v1/crews"
# Deploy
def deploy_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"POST", f"{self.RESOURCE}/by-name/{project_name}/deploy"
)
def deploy_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("POST", f"{self.RESOURCE}/{uuid}/deploy")
# Status
def status_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"GET", f"{self.RESOURCE}/by-name/{project_name}/status"
)
def status_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("GET", f"{self.RESOURCE}/{uuid}/status")
# Logs
def logs_by_name(
self, project_name: str, log_type: str = "deployment"
) -> requests.Response:
return self._make_request(
"GET", f"{self.RESOURCE}/by-name/{project_name}/logs/{log_type}"
)
def logs_by_uuid(
self, uuid: str, log_type: str = "deployment"
) -> requests.Response:
return self._make_request("GET", f"{self.RESOURCE}/{uuid}/logs/{log_type}")
# Delete
def delete_by_name(self, project_name: str) -> requests.Response:
return self._make_request("DELETE", f"{self.RESOURCE}/by-name/{project_name}")
def delete_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("DELETE", f"{self.RESOURCE}/{uuid}")
# List
def list_crews(self) -> requests.Response:
return self._make_request("GET", self.RESOURCE)
# Create
def create_crew(self, payload) -> requests.Response:
return self._make_request("POST", self.RESOURCE, json=payload)

View File

@@ -2,19 +2,17 @@ from typing import Any, Dict, List, Optional
from rich.console import Console
from crewai.telemetry import Telemetry
from crewai.cli.command import BaseCommand, PlusAPIMixin
from crewai.cli.utils import (
fetch_and_json_env_file,
get_auth_token,
get_git_remote_url,
get_project_name,
)
from .api import CrewAPI
console = Console()
class DeployCommand:
class DeployCommand(BaseCommand, PlusAPIMixin):
"""
A class to handle deployment-related operations for CrewAI projects.
"""
@@ -23,40 +21,10 @@ class DeployCommand:
"""
Initialize the DeployCommand with project name and API client.
"""
try:
self._telemetry = Telemetry()
self._telemetry.set_tracer()
access_token = get_auth_token()
except Exception:
self._deploy_signup_error_span = self._telemetry.deploy_signup_error_span()
console.print(
"Please sign up/login to CrewAI+ before using the CLI.",
style="bold red",
)
console.print("Run 'crewai signup' to sign up/login.", style="bold green")
raise SystemExit
self.project_name = get_project_name()
if self.project_name is None:
console.print(
"No project name found. Please ensure your project has a valid pyproject.toml file.",
style="bold red",
)
raise SystemExit
self.client = CrewAPI(api_key=access_token)
def _handle_error(self, json_response: Dict[str, Any]) -> None:
"""
Handle and display error messages from API responses.
Args:
json_response (Dict[str, Any]): The JSON response containing error information.
"""
error = json_response.get("error", "Unknown error")
message = json_response.get("message", "No message provided")
console.print(f"Error: {error}", style="bold red")
console.print(f"Message: {message}", style="bold red")
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
self.project_name = get_project_name(require=True)
def _standard_no_param_error_message(self) -> None:
"""
@@ -104,9 +72,9 @@ class DeployCommand:
self._start_deployment_span = self._telemetry.start_deployment_span(uuid)
console.print("Starting deployment...", style="bold blue")
if uuid:
response = self.client.deploy_by_uuid(uuid)
response = self.plus_api_client.deploy_by_uuid(uuid)
elif self.project_name:
response = self.client.deploy_by_name(self.project_name)
response = self.plus_api_client.deploy_by_name(self.project_name)
else:
self._standard_no_param_error_message()
return
@@ -115,7 +83,7 @@ class DeployCommand:
if response.status_code == 200:
self._display_deployment_info(json_response)
else:
self._handle_error(json_response)
self._handle_plus_api_error(json_response)
def create_crew(self, confirm: bool = False) -> None:
"""
@@ -139,11 +107,11 @@ class DeployCommand:
self._confirm_input(env_vars, remote_repo_url, confirm)
payload = self._create_payload(env_vars, remote_repo_url)
response = self.client.create_crew(payload)
response = self.plus_api_client.create_crew(payload)
if response.status_code == 201:
self._display_creation_success(response.json())
else:
self._handle_error(response.json())
self._handle_plus_api_error(response.json())
def _confirm_input(
self, env_vars: Dict[str, str], remote_repo_url: str, confirm: bool
@@ -208,7 +176,7 @@ class DeployCommand:
"""
console.print("Listing all Crews\n", style="bold blue")
response = self.client.list_crews()
response = self.plus_api_client.list_crews()
json_response = response.json()
if response.status_code == 200:
self._display_crews(json_response)
@@ -243,9 +211,9 @@ class DeployCommand:
"""
console.print("Fetching deployment status...", style="bold blue")
if uuid:
response = self.client.status_by_uuid(uuid)
response = self.plus_api_client.crew_status_by_uuid(uuid)
elif self.project_name:
response = self.client.status_by_name(self.project_name)
response = self.plus_api_client.crew_status_by_name(self.project_name)
else:
self._standard_no_param_error_message()
return
@@ -254,7 +222,7 @@ class DeployCommand:
if response.status_code == 200:
self._display_crew_status(json_response)
else:
self._handle_error(json_response)
self._handle_plus_api_error(json_response)
def _display_crew_status(self, status_data: Dict[str, str]) -> None:
"""
@@ -278,9 +246,9 @@ class DeployCommand:
console.print(f"Fetching {log_type} logs...", style="bold blue")
if uuid:
response = self.client.logs_by_uuid(uuid, log_type)
response = self.plus_api_client.crew_by_uuid(uuid, log_type)
elif self.project_name:
response = self.client.logs_by_name(self.project_name, log_type)
response = self.plus_api_client.crew_by_name(self.project_name, log_type)
else:
self._standard_no_param_error_message()
return
@@ -288,7 +256,7 @@ class DeployCommand:
if response.status_code == 200:
self._display_logs(response.json())
else:
self._handle_error(response.json())
self._handle_plus_api_error(response.json())
def remove_crew(self, uuid: Optional[str]) -> None:
"""
@@ -301,9 +269,9 @@ class DeployCommand:
console.print("Removing deployment...", style="bold blue")
if uuid:
response = self.client.delete_by_uuid(uuid)
response = self.plus_api_client.delete_crew_by_uuid(uuid)
elif self.project_name:
response = self.client.delete_by_name(self.project_name)
response = self.plus_api_client.delete_crew_by_name(self.project_name)
else:
self._standard_no_param_error_message()
return

View File

@@ -1,3 +1,4 @@
from typing import Optional
import requests
from os import getenv
from crewai.cli.utils import get_crewai_version
@@ -9,6 +10,9 @@ class PlusAPI:
This class exposes methods for working with the CrewAI+ API.
"""
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
def __init__(self, api_key: str) -> None:
self.api_key = api_key
self.headers = {
@@ -22,3 +26,67 @@ class PlusAPI:
def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
url = urljoin(self.base_url, endpoint)
return requests.request(method, url, headers=self.headers, **kwargs)
def get_tool(self, handle: str):
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
def publish_tool(
self,
handle: str,
is_public: bool,
version: str,
description: Optional[str],
encoded_file: str,
):
params = {
"handle": handle,
"public": is_public,
"version": version,
"file": encoded_file,
"description": description,
}
return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params)
def deploy_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy"
)
def deploy_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("POST", f"{self.CREWS_RESOURCE}/{uuid}/deploy")
def crew_status_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status"
)
def crew_status_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")
def crew_by_name(
self, project_name: str, log_type: str = "deployment"
) -> requests.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}"
)
def crew_by_uuid(
self, uuid: str, log_type: str = "deployment"
) -> requests.Response:
return self._make_request(
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}"
)
def delete_crew_by_name(self, project_name: str) -> requests.Response:
return self._make_request(
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}"
)
def delete_crew_by_uuid(self, uuid: str) -> requests.Response:
return self._make_request("DELETE", f"{self.CREWS_RESOURCE}/{uuid}")
def list_crews(self) -> requests.Response:
return self._make_request("GET", self.CREWS_RESOURCE)
def create_crew(self, payload) -> requests.Response:
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)

View File

@@ -1,26 +0,0 @@
from typing import Optional
from crewai.cli.plus_api import PlusAPI
class ToolsAPI(PlusAPI):
RESOURCE = "/crewai_plus/api/v1/tools"
def get(self, handle: str):
return self._make_request("GET", f"{self.RESOURCE}/{handle}")
def publish(
self,
handle: str,
public: bool,
version: str,
description: Optional[str],
encoded_file: str,
):
params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
}
return self._make_request("POST", f"{self.RESOURCE}", json=params)

View File

@@ -0,0 +1,168 @@
import base64
import click
import os
import subprocess
import tempfile
from crewai.cli.command import BaseCommand, PlusAPIMixin
from crewai.cli.utils import (
get_project_name,
get_project_description,
get_project_version,
)
from rich.console import Console
console = Console()
class ToolCommand(BaseCommand, PlusAPIMixin):
"""
A class to handle tool repository related operations for CrewAI projects.
"""
def __init__(self):
BaseCommand.__init__(self)
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
def publish(self, is_public: bool):
project_name = get_project_name(require=True)
assert isinstance(project_name, str)
project_version = get_project_version(require=True)
assert isinstance(project_version, str)
project_description = get_project_description(require=False)
encoded_tarball = None
with tempfile.TemporaryDirectory() as temp_build_dir:
subprocess.run(
["poetry", "build", "-f", "sdist", "--output", temp_build_dir],
check=True,
capture_output=False,
)
tarball_filename = next(
(f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None
)
if not tarball_filename:
console.print(
"Project build failed. Please ensure that the command `poetry build -f sdist` completes successfully.",
style="bold red",
)
raise SystemExit
tarball_path = os.path.join(temp_build_dir, tarball_filename)
with open(tarball_path, "rb") as file:
tarball_contents = file.read()
encoded_tarball = base64.b64encode(tarball_contents).decode("utf-8")
publish_response = self.plus_api_client.publish_tool(
handle=project_name,
is_public=is_public,
version=project_version,
description=project_description,
encoded_file=f"data:application/x-gzip;base64,{encoded_tarball}",
)
if publish_response.status_code == 422:
console.print(
"[bold red]Failed to publish tool. Please fix the following errors:[/bold red]"
)
for field, messages in publish_response.json().items():
for message in messages:
console.print(
f"* [bold red]{field.capitalize()}[/bold red] {message}"
)
raise SystemExit
elif publish_response.status_code != 200:
self._handle_plus_api_error(publish_response.json())
console.print(
"Failed to publish tool. Please try again later.", style="bold red"
)
raise SystemExit
published_handle = publish_response.json()["handle"]
console.print(
f"Succesfully published {published_handle} ({project_version}).\nInstall it in other projects with crewai tool install {published_handle}",
style="bold green",
)
def install(self, handle: str):
get_response = self.plus_api_client.get_tool(handle)
if get_response.status_code == 404:
console.print(
"No tool found with this name. Please ensure the tool was published and you have access to it.",
style="bold red",
)
raise SystemExit
elif get_response.status_code != 200:
console.print(
"Failed to get tool details. Please try again later.", style="bold red"
)
raise SystemExit
self._add_repository_to_poetry(get_response.json())
self._add_package(get_response.json())
console.print(f"Succesfully installed {handle}", style="bold green")
def _add_repository_to_poetry(self, tool_details):
repository_handle = f"crewai-{tool_details['repository']['handle']}"
repository_url = tool_details["repository"]["url"]
repository_credentials = tool_details["repository"]["credentials"]
add_repository_command = [
"poetry",
"source",
"add",
"--priority=explicit",
repository_handle,
repository_url,
]
add_repository_result = subprocess.run(
add_repository_command, text=True, check=True
)
if add_repository_result.stderr:
click.echo(add_repository_result.stderr, err=True)
raise SystemExit
add_repository_credentials_command = [
"poetry",
"config",
f"http-basic.{repository_handle}",
repository_credentials,
'""',
]
add_repository_credentials_result = subprocess.run(
add_repository_credentials_command,
capture_output=False,
text=True,
check=True,
)
if add_repository_credentials_result.stderr:
click.echo(add_repository_credentials_result.stderr, err=True)
raise SystemExit
def _add_package(self, tool_details):
tool_handle = tool_details["handle"]
repository_handle = tool_details["repository"]["handle"]
pypi_index_handle = f"crewai-{repository_handle}"
add_package_command = [
"poetry",
"add",
"--source",
pypi_index_handle,
tool_handle,
]
add_package_result = subprocess.run(
add_package_command, capture_output=False, text=True, check=True
)
if add_package_result.stderr:
click.echo(add_package_result.stderr, err=True)
raise SystemExit

View File

@@ -3,8 +3,10 @@ import re
import subprocess
import sys
from rich.console import Console
from crewai.cli.authentication.utils import TokenManager
from functools import reduce
from rich.console import Console
from typing import Any, Dict, List
if sys.version_info >= (3, 11):
import tomllib
@@ -88,21 +90,51 @@ def get_git_remote_url() -> str | None:
return None
def get_project_name(pyproject_path: str = "pyproject.toml") -> str | None:
def get_project_name(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project name from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["tool", "poetry", "name"], require=require
)
def get_project_version(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project version from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["tool", "poetry", "version"], require=require
)
def get_project_description(
pyproject_path: str = "pyproject.toml", require: bool = False
) -> str | None:
"""Get the project description from the pyproject.toml file."""
return _get_project_attribute(
pyproject_path, ["tool", "poetry", "description"], require=require
)
def _get_project_attribute(
pyproject_path: str, keys: List[str], require: bool
) -> Any | None:
"""Get an attribute from the pyproject.toml file."""
attribute = None
try:
# Read the pyproject.toml file
with open(pyproject_path, "r") as f:
pyproject_content = parse_toml(f.read())
# Extract the project name
project_name = pyproject_content["tool"]["poetry"]["name"]
if "crewai" not in pyproject_content["tool"]["poetry"]["dependencies"]:
dependencies = (
_get_nested_value(pyproject_content, ["tool", "poetry", "dependencies"])
or {}
)
if "crewai" not in dependencies:
raise Exception("crewai is not in the dependencies.")
return project_name
attribute = _get_nested_value(pyproject_content, keys)
except FileNotFoundError:
print(f"Error: {pyproject_path} not found.")
except KeyError:
@@ -116,7 +148,18 @@ def get_project_name(pyproject_path: str = "pyproject.toml") -> str | None:
except Exception as e:
print(f"Error reading the pyproject.toml file: {e}")
return None
if require and not attribute:
console.print(
f"Unable to read '{'.'.join(keys)}' in the pyproject.toml file. Please verify that the file exists and contains the specified attribute.",
style="bold red",
)
raise SystemExit
return attribute
def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
return reduce(dict.__getitem__, keys, data)
def get_crewai_version(poetry_lock_path: str = "poetry.lock") -> str:

View File

@@ -8,34 +8,34 @@ from crewai.cli.utils import parse_toml
class TestDeployCommand(unittest.TestCase):
@patch("crewai.cli.deploy.main.get_auth_token")
@patch("crewai.cli.command.get_auth_token")
@patch("crewai.cli.deploy.main.get_project_name")
@patch("crewai.cli.deploy.main.CrewAPI")
def setUp(self, mock_crew_api, mock_get_project_name, mock_get_auth_token):
@patch("crewai.cli.command.PlusAPI")
def setUp(self, mock_plus_api, mock_get_project_name, mock_get_auth_token):
self.mock_get_auth_token = mock_get_auth_token
self.mock_get_project_name = mock_get_project_name
self.mock_crew_api = mock_crew_api
self.mock_plus_api = mock_plus_api
self.mock_get_auth_token.return_value = "test_token"
self.mock_get_project_name.return_value = "test_project"
self.deploy_command = DeployCommand()
self.mock_client = self.deploy_command.client
self.mock_client = self.deploy_command.plus_api_client
def test_init_success(self):
self.assertEqual(self.deploy_command.project_name, "test_project")
self.mock_crew_api.assert_called_once_with(api_key="test_token")
self.mock_plus_api.assert_called_once_with(api_key="test_token")
@patch("crewai.cli.deploy.main.get_auth_token")
@patch("crewai.cli.command.get_auth_token")
def test_init_failure(self, mock_get_auth_token):
mock_get_auth_token.side_effect = Exception("Auth failed")
with self.assertRaises(SystemExit):
DeployCommand()
def test_handle_error(self):
def test_handle_plus_api_error(self):
with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command._handle_error(
self.deploy_command._handle_plus_api_error(
{"error": "Test error", "message": "Test message"}
)
self.assertIn("Error: Test error", fake_out.getvalue())
@@ -122,7 +122,7 @@ class TestDeployCommand(unittest.TestCase):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"name": "TestCrew", "status": "active"}
self.mock_client.status_by_name.return_value = mock_response
self.mock_client.crew_status_by_name.return_value = mock_response
with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command.get_crew_status()
@@ -136,7 +136,7 @@ class TestDeployCommand(unittest.TestCase):
{"timestamp": "2023-01-01", "level": "INFO", "message": "Log1"},
{"timestamp": "2023-01-02", "level": "ERROR", "message": "Log2"},
]
self.mock_client.logs_by_name.return_value = mock_response
self.mock_client.crew_by_name.return_value = mock_response
with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command.get_crew_logs(None)
@@ -146,7 +146,7 @@ class TestDeployCommand(unittest.TestCase):
def test_remove_crew(self):
mock_response = MagicMock()
mock_response.status_code = 204
self.mock_client.delete_by_name.return_value = mock_response
self.mock_client.delete_crew_by_name.return_value = mock_response
with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command.remove_crew(None)

View File

@@ -1,14 +1,13 @@
import os
import unittest
from os import environ
from unittest.mock import MagicMock, patch
from crewai.cli.deploy.api import CrewAPI
from crewai.cli.plus_api import PlusAPI
class TestCrewAPI(unittest.TestCase):
class TestPlusAPI(unittest.TestCase):
def setUp(self):
self.api_key = "test_api_key"
self.api = CrewAPI(self.api_key)
self.api = PlusAPI(self.api_key)
def test_init(self):
self.assertEqual(self.api.api_key, self.api_key)
@@ -22,6 +21,70 @@ class TestCrewAPI(unittest.TestCase):
},
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_get_tool(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.get_tool("test_tool_handle")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_publish_tool(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
public = True
version = "1.0.0"
description = "Test tool description"
encoded_file = "encoded_test_file"
response = self.api.publish_tool(
handle, public, version, description, encoded_file
)
params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
}
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools", json=params
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_publish_tool_without_description(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
public = False
version = "2.0.0"
description = None
encoded_file = "encoded_test_file"
response = self.api.publish_tool(
handle, public, version, description, encoded_file
)
params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
}
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools", json=params
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.requests.request")
def test_make_request(self, mock_request):
mock_response = MagicMock()
@@ -49,53 +112,53 @@ class TestCrewAPI(unittest.TestCase):
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_status_by_name(self, mock_make_request):
self.api.status_by_name("test_project")
def test_crew_status_by_name(self, mock_make_request):
self.api.crew_status_by_name("test_project")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/status"
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_status_by_uuid(self, mock_make_request):
self.api.status_by_uuid("test_uuid")
def test_crew_status_by_uuid(self, mock_make_request):
self.api.crew_status_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/status"
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_logs_by_name(self, mock_make_request):
self.api.logs_by_name("test_project")
def test_crew_by_name(self, mock_make_request):
self.api.crew_by_name("test_project")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/deployment"
)
self.api.logs_by_name("test_project", "custom_log")
self.api.crew_by_name("test_project", "custom_log")
mock_make_request.assert_called_with(
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log"
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_logs_by_uuid(self, mock_make_request):
self.api.logs_by_uuid("test_uuid")
def test_crew_by_uuid(self, mock_make_request):
self.api.crew_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/deployment"
)
self.api.logs_by_uuid("test_uuid", "custom_log")
self.api.crew_by_uuid("test_uuid", "custom_log")
mock_make_request.assert_called_with(
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log"
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_delete_by_name(self, mock_make_request):
self.api.delete_by_name("test_project")
def test_delete_crew_by_name(self, mock_make_request):
self.api.delete_crew_by_name("test_project")
mock_make_request.assert_called_once_with(
"DELETE", "/crewai_plus/api/v1/crews/by-name/test_project"
)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_delete_by_uuid(self, mock_make_request):
self.api.delete_by_uuid("test_uuid")
def test_delete_crew_by_uuid(self, mock_make_request):
self.api.delete_crew_by_uuid("test_uuid")
mock_make_request.assert_called_once_with(
"DELETE", "/crewai_plus/api/v1/crews/test_uuid"
)
@@ -105,7 +168,7 @@ class TestCrewAPI(unittest.TestCase):
self.api.list_crews()
mock_make_request.assert_called_once_with("GET", "/crewai_plus/api/v1/crews")
@patch("crewai.cli.deploy.api.CrewAPI._make_request")
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_create_crew(self, mock_make_request):
payload = {"name": "test_crew"}
self.api.create_crew(payload)
@@ -113,9 +176,9 @@ class TestCrewAPI(unittest.TestCase):
"POST", "/crewai_plus/api/v1/crews", json=payload
)
@patch.dict(environ, {"CREWAI_BASE_URL": "https://custom-url.com/api"})
@patch.dict(os.environ, {"CREWAI_BASE_URL": "https://custom-url.com/api"})
def test_custom_base_url(self):
custom_api = CrewAPI("test_key")
custom_api = PlusAPI("test_key")
self.assertEqual(
custom_api.base_url,
"https://custom-url.com/api",

View File

@@ -1,69 +0,0 @@
import unittest
from unittest.mock import MagicMock, patch
from crewai.cli.tools.api import ToolsAPI
class TestToolsAPI(unittest.TestCase):
def setUp(self):
self.api_key = "test_api_key"
self.api = ToolsAPI(self.api_key)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_get_tool(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.get("test_tool_handle")
mock_make_request.assert_called_once_with(
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_publish_tool(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
public = True
version = "1.0.0"
description = "Test tool description"
encoded_file = "encoded_test_file"
response = self.api.publish(handle, public, version, description, encoded_file)
params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
}
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools", json=params
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_publish_tool_without_description(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
handle = "test_tool_handle"
public = False
version = "2.0.0"
description = None
encoded_file = "encoded_test_file"
response = self.api.publish(handle, public, version, description, encoded_file)
params = {
"handle": handle,
"public": public,
"version": version,
"file": encoded_file,
"description": description,
}
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools", json=params
)
self.assertEqual(response, mock_response)

View File

@@ -0,0 +1,229 @@
import unittest
import unittest.mock
from crewai.cli.tools.main import ToolCommand
from io import StringIO
from unittest.mock import patch, MagicMock
class TestToolCommand(unittest.TestCase):
@patch("crewai.cli.tools.main.subprocess.run")
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
def test_install_success(self, mock_get, mock_subprocess_run):
mock_get_response = MagicMock()
mock_get_response.status_code = 200
mock_get_response.json.return_value = {
"handle": "sample-tool",
"repository": {
"handle": "sample-repo",
"url": "https://example.com/repo",
"credentials": "my_very_secret",
},
}
mock_get.return_value = mock_get_response
mock_subprocess_run.return_value = MagicMock(stderr=None)
tool_command = ToolCommand()
with patch("sys.stdout", new=StringIO()) as fake_out:
tool_command.install("sample-tool")
output = fake_out.getvalue()
mock_get.assert_called_once_with("sample-tool")
mock_subprocess_run.assert_any_call(
[
"poetry",
"source",
"add",
"--priority=explicit",
"crewai-sample-repo",
"https://example.com/repo",
],
text=True,
check=True,
)
mock_subprocess_run.assert_any_call(
[
"poetry",
"config",
"http-basic.crewai-sample-repo",
"my_very_secret",
'""',
],
capture_output=False,
text=True,
check=True,
)
mock_subprocess_run.assert_any_call(
["poetry", "add", "--source", "crewai-sample-repo", "sample-tool"],
capture_output=False,
text=True,
check=True,
)
self.assertIn("Succesfully installed sample-tool", output)
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
def test_install_tool_not_found(self, mock_get):
mock_get_response = MagicMock()
mock_get_response.status_code = 404
mock_get.return_value = mock_get_response
tool_command = ToolCommand()
with patch("sys.stdout", new=StringIO()) as fake_out:
with self.assertRaises(SystemExit):
tool_command.install("non-existent-tool")
output = fake_out.getvalue()
mock_get.assert_called_once_with("non-existent-tool")
self.assertIn("No tool found with this name", output)
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
def test_install_api_error(self, mock_get):
mock_get_response = MagicMock()
mock_get_response.status_code = 500
mock_get.return_value = mock_get_response
tool_command = ToolCommand()
with patch("sys.stdout", new=StringIO()) as fake_out:
with self.assertRaises(SystemExit):
tool_command.install("error-tool")
output = fake_out.getvalue()
mock_get.assert_called_once_with("error-tool")
self.assertIn("Failed to get tool details", output)
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
@patch(
"crewai.cli.tools.main.get_project_description", return_value="A sample tool"
)
@patch("crewai.cli.tools.main.subprocess.run")
@patch(
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
)
@patch(
"crewai.cli.tools.main.open",
new_callable=unittest.mock.mock_open,
read_data=b"sample tarball content",
)
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
def test_publish_success(
self,
mock_publish,
mock_open,
mock_listdir,
mock_subprocess_run,
mock_get_project_description,
mock_get_project_version,
mock_get_project_name,
):
mock_publish_response = MagicMock()
mock_publish_response.status_code = 200
mock_publish_response.json.return_value = {"handle": "sample-tool"}
mock_publish.return_value = mock_publish_response
tool_command = ToolCommand()
tool_command.publish(is_public=True)
mock_get_project_name.assert_called_once_with(require=True)
mock_get_project_version.assert_called_once_with(require=True)
mock_get_project_description.assert_called_once_with(require=False)
mock_subprocess_run.assert_called_once_with(
["poetry", "build", "-f", "sdist", "--output", unittest.mock.ANY],
check=True,
capture_output=False,
)
mock_open.assert_called_once_with(unittest.mock.ANY, "rb")
mock_publish.assert_called_once_with(
handle="sample-tool",
is_public=True,
version="1.0.0",
description="A sample tool",
encoded_file=unittest.mock.ANY,
)
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
@patch(
"crewai.cli.tools.main.get_project_description", return_value="A sample tool"
)
@patch("crewai.cli.tools.main.subprocess.run")
@patch(
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
)
@patch(
"crewai.cli.tools.main.open",
new_callable=unittest.mock.mock_open,
read_data=b"sample tarball content",
)
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
def test_publish_failure(
self,
mock_publish,
mock_open,
mock_listdir,
mock_subprocess_run,
mock_get_project_description,
mock_get_project_version,
mock_get_project_name,
):
mock_publish_response = MagicMock()
mock_publish_response.status_code = 422
mock_publish_response.json.return_value = {"name": ["is already taken"]}
mock_publish.return_value = mock_publish_response
tool_command = ToolCommand()
with patch("sys.stdout", new=StringIO()) as fake_out:
with self.assertRaises(SystemExit):
tool_command.publish(is_public=True)
output = fake_out.getvalue()
mock_publish.assert_called_once()
self.assertIn("Failed to publish tool", output)
self.assertIn("Name is already taken", output)
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
@patch(
"crewai.cli.tools.main.get_project_description", return_value="A sample tool"
)
@patch("crewai.cli.tools.main.subprocess.run")
@patch(
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
)
@patch(
"crewai.cli.tools.main.open",
new_callable=unittest.mock.mock_open,
read_data=b"sample tarball content",
)
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
def test_publish_api_error(
self,
mock_publish,
mock_open,
mock_listdir,
mock_subprocess_run,
mock_get_project_description,
mock_get_project_version,
mock_get_project_name,
):
mock_publish_response = MagicMock()
mock_publish_response.status_code = 500
mock_publish.return_value = mock_publish_response
tool_command = ToolCommand()
with patch("sys.stdout", new=StringIO()) as fake_out:
with self.assertRaises(SystemExit):
tool_command.publish(is_public=True)
output = fake_out.getvalue()
mock_publish.assert_called_once()
self.assertIn("Failed to publish tool", output)
if __name__ == "__main__":
unittest.main()