mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
CLI for Tool Repository (#1357)
This commit adds two commands to the CLI:
- `crewai tool publish`
- Builds the project using Poetry
- Uploads the tarball to CrewAI's tool repository
- `crewai tool install my-tool`
- Adds my-tool's index to Poetry and its credentials
- Installs my-tool from the custom index
This commit is contained in:
@@ -11,6 +11,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||
|
||||
from .authentication.main import AuthenticationCommand
|
||||
from .deploy.main import DeployCommand
|
||||
from .tools.main import ToolCommand
|
||||
from .evaluate_crew import evaluate_crew
|
||||
from .install_crew import install_crew
|
||||
from .replay_from_task import replay_task_command
|
||||
@@ -202,6 +203,12 @@ def deploy():
|
||||
pass
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def tool():
|
||||
"""Tool Repository related commands."""
|
||||
pass
|
||||
|
||||
|
||||
@deploy.command(name="create")
|
||||
@click.option("-y", "--yes", is_flag=True, help="Skip the confirmation prompt")
|
||||
def deploy_create(yes: bool):
|
||||
@@ -249,5 +256,20 @@ def deploy_remove(uuid: Optional[str]):
|
||||
deploy_cmd.remove_crew(uuid=uuid)
|
||||
|
||||
|
||||
@tool.command(name="install")
|
||||
@click.argument("handle")
|
||||
def tool_install(handle: str):
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.install(handle)
|
||||
|
||||
|
||||
@tool.command(name="publish")
|
||||
@click.option("--public", "is_public", flag_value=True, default=False)
|
||||
@click.option("--private", "is_public", flag_value=False)
|
||||
def tool_publish(is_public: bool):
|
||||
tool_cmd = ToolCommand()
|
||||
tool_cmd.publish(is_public)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
crewai()
|
||||
|
||||
40
src/crewai/cli/command.py
Normal file
40
src/crewai/cli/command.py
Normal file
@@ -0,0 +1,40 @@
|
||||
from typing import Dict, Any
|
||||
from rich.console import Console
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
from crewai.cli.utils import get_auth_token
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class BaseCommand:
|
||||
def __init__(self):
|
||||
self._telemetry = Telemetry()
|
||||
self._telemetry.set_tracer()
|
||||
|
||||
|
||||
class PlusAPIMixin:
|
||||
def __init__(self, telemetry):
|
||||
try:
|
||||
telemetry.set_tracer()
|
||||
self.plus_api_client = PlusAPI(api_key=get_auth_token())
|
||||
except Exception:
|
||||
self._deploy_signup_error_span = telemetry.deploy_signup_error_span()
|
||||
console.print(
|
||||
"Please sign up/login to CrewAI+ before using the CLI.",
|
||||
style="bold red",
|
||||
)
|
||||
console.print("Run 'crewai signup' to sign up/login.", style="bold green")
|
||||
raise SystemExit
|
||||
|
||||
def _handle_plus_api_error(self, json_response: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Handle and display error messages from API responses.
|
||||
|
||||
Args:
|
||||
json_response (Dict[str, Any]): The JSON response containing error information.
|
||||
"""
|
||||
error = json_response.get("error", "Unknown error")
|
||||
message = json_response.get("message", "No message provided")
|
||||
console.print(f"Error: {error}", style="bold red")
|
||||
console.print(f"Message: {message}", style="bold red")
|
||||
@@ -1,56 +0,0 @@
|
||||
import requests
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
|
||||
|
||||
class CrewAPI(PlusAPI):
|
||||
"""
|
||||
CrewAPI class to interact with the Crew resource in CrewAI+ API.
|
||||
"""
|
||||
|
||||
RESOURCE = "/crewai_plus/api/v1/crews"
|
||||
|
||||
# Deploy
|
||||
def deploy_by_name(self, project_name: str) -> requests.Response:
|
||||
return self._make_request(
|
||||
"POST", f"{self.RESOURCE}/by-name/{project_name}/deploy"
|
||||
)
|
||||
|
||||
def deploy_by_uuid(self, uuid: str) -> requests.Response:
|
||||
return self._make_request("POST", f"{self.RESOURCE}/{uuid}/deploy")
|
||||
|
||||
# Status
|
||||
def status_by_name(self, project_name: str) -> requests.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.RESOURCE}/by-name/{project_name}/status"
|
||||
)
|
||||
|
||||
def status_by_uuid(self, uuid: str) -> requests.Response:
|
||||
return self._make_request("GET", f"{self.RESOURCE}/{uuid}/status")
|
||||
|
||||
# Logs
|
||||
def logs_by_name(
|
||||
self, project_name: str, log_type: str = "deployment"
|
||||
) -> requests.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.RESOURCE}/by-name/{project_name}/logs/{log_type}"
|
||||
)
|
||||
|
||||
def logs_by_uuid(
|
||||
self, uuid: str, log_type: str = "deployment"
|
||||
) -> requests.Response:
|
||||
return self._make_request("GET", f"{self.RESOURCE}/{uuid}/logs/{log_type}")
|
||||
|
||||
# Delete
|
||||
def delete_by_name(self, project_name: str) -> requests.Response:
|
||||
return self._make_request("DELETE", f"{self.RESOURCE}/by-name/{project_name}")
|
||||
|
||||
def delete_by_uuid(self, uuid: str) -> requests.Response:
|
||||
return self._make_request("DELETE", f"{self.RESOURCE}/{uuid}")
|
||||
|
||||
# List
|
||||
def list_crews(self) -> requests.Response:
|
||||
return self._make_request("GET", self.RESOURCE)
|
||||
|
||||
# Create
|
||||
def create_crew(self, payload) -> requests.Response:
|
||||
return self._make_request("POST", self.RESOURCE, json=payload)
|
||||
@@ -2,19 +2,17 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
from crewai.telemetry import Telemetry
|
||||
from crewai.cli.command import BaseCommand, PlusAPIMixin
|
||||
from crewai.cli.utils import (
|
||||
fetch_and_json_env_file,
|
||||
get_auth_token,
|
||||
get_git_remote_url,
|
||||
get_project_name,
|
||||
)
|
||||
from .api import CrewAPI
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class DeployCommand:
|
||||
class DeployCommand(BaseCommand, PlusAPIMixin):
|
||||
"""
|
||||
A class to handle deployment-related operations for CrewAI projects.
|
||||
"""
|
||||
@@ -23,40 +21,10 @@ class DeployCommand:
|
||||
"""
|
||||
Initialize the DeployCommand with project name and API client.
|
||||
"""
|
||||
try:
|
||||
self._telemetry = Telemetry()
|
||||
self._telemetry.set_tracer()
|
||||
access_token = get_auth_token()
|
||||
except Exception:
|
||||
self._deploy_signup_error_span = self._telemetry.deploy_signup_error_span()
|
||||
console.print(
|
||||
"Please sign up/login to CrewAI+ before using the CLI.",
|
||||
style="bold red",
|
||||
)
|
||||
console.print("Run 'crewai signup' to sign up/login.", style="bold green")
|
||||
raise SystemExit
|
||||
|
||||
self.project_name = get_project_name()
|
||||
if self.project_name is None:
|
||||
console.print(
|
||||
"No project name found. Please ensure your project has a valid pyproject.toml file.",
|
||||
style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
self.client = CrewAPI(api_key=access_token)
|
||||
|
||||
def _handle_error(self, json_response: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Handle and display error messages from API responses.
|
||||
|
||||
Args:
|
||||
json_response (Dict[str, Any]): The JSON response containing error information.
|
||||
"""
|
||||
error = json_response.get("error", "Unknown error")
|
||||
message = json_response.get("message", "No message provided")
|
||||
console.print(f"Error: {error}", style="bold red")
|
||||
console.print(f"Message: {message}", style="bold red")
|
||||
BaseCommand.__init__(self)
|
||||
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
|
||||
self.project_name = get_project_name(require=True)
|
||||
|
||||
def _standard_no_param_error_message(self) -> None:
|
||||
"""
|
||||
@@ -104,9 +72,9 @@ class DeployCommand:
|
||||
self._start_deployment_span = self._telemetry.start_deployment_span(uuid)
|
||||
console.print("Starting deployment...", style="bold blue")
|
||||
if uuid:
|
||||
response = self.client.deploy_by_uuid(uuid)
|
||||
response = self.plus_api_client.deploy_by_uuid(uuid)
|
||||
elif self.project_name:
|
||||
response = self.client.deploy_by_name(self.project_name)
|
||||
response = self.plus_api_client.deploy_by_name(self.project_name)
|
||||
else:
|
||||
self._standard_no_param_error_message()
|
||||
return
|
||||
@@ -115,7 +83,7 @@ class DeployCommand:
|
||||
if response.status_code == 200:
|
||||
self._display_deployment_info(json_response)
|
||||
else:
|
||||
self._handle_error(json_response)
|
||||
self._handle_plus_api_error(json_response)
|
||||
|
||||
def create_crew(self, confirm: bool = False) -> None:
|
||||
"""
|
||||
@@ -139,11 +107,11 @@ class DeployCommand:
|
||||
self._confirm_input(env_vars, remote_repo_url, confirm)
|
||||
payload = self._create_payload(env_vars, remote_repo_url)
|
||||
|
||||
response = self.client.create_crew(payload)
|
||||
response = self.plus_api_client.create_crew(payload)
|
||||
if response.status_code == 201:
|
||||
self._display_creation_success(response.json())
|
||||
else:
|
||||
self._handle_error(response.json())
|
||||
self._handle_plus_api_error(response.json())
|
||||
|
||||
def _confirm_input(
|
||||
self, env_vars: Dict[str, str], remote_repo_url: str, confirm: bool
|
||||
@@ -208,7 +176,7 @@ class DeployCommand:
|
||||
"""
|
||||
console.print("Listing all Crews\n", style="bold blue")
|
||||
|
||||
response = self.client.list_crews()
|
||||
response = self.plus_api_client.list_crews()
|
||||
json_response = response.json()
|
||||
if response.status_code == 200:
|
||||
self._display_crews(json_response)
|
||||
@@ -243,9 +211,9 @@ class DeployCommand:
|
||||
"""
|
||||
console.print("Fetching deployment status...", style="bold blue")
|
||||
if uuid:
|
||||
response = self.client.status_by_uuid(uuid)
|
||||
response = self.plus_api_client.crew_status_by_uuid(uuid)
|
||||
elif self.project_name:
|
||||
response = self.client.status_by_name(self.project_name)
|
||||
response = self.plus_api_client.crew_status_by_name(self.project_name)
|
||||
else:
|
||||
self._standard_no_param_error_message()
|
||||
return
|
||||
@@ -254,7 +222,7 @@ class DeployCommand:
|
||||
if response.status_code == 200:
|
||||
self._display_crew_status(json_response)
|
||||
else:
|
||||
self._handle_error(json_response)
|
||||
self._handle_plus_api_error(json_response)
|
||||
|
||||
def _display_crew_status(self, status_data: Dict[str, str]) -> None:
|
||||
"""
|
||||
@@ -278,9 +246,9 @@ class DeployCommand:
|
||||
console.print(f"Fetching {log_type} logs...", style="bold blue")
|
||||
|
||||
if uuid:
|
||||
response = self.client.logs_by_uuid(uuid, log_type)
|
||||
response = self.plus_api_client.crew_by_uuid(uuid, log_type)
|
||||
elif self.project_name:
|
||||
response = self.client.logs_by_name(self.project_name, log_type)
|
||||
response = self.plus_api_client.crew_by_name(self.project_name, log_type)
|
||||
else:
|
||||
self._standard_no_param_error_message()
|
||||
return
|
||||
@@ -288,7 +256,7 @@ class DeployCommand:
|
||||
if response.status_code == 200:
|
||||
self._display_logs(response.json())
|
||||
else:
|
||||
self._handle_error(response.json())
|
||||
self._handle_plus_api_error(response.json())
|
||||
|
||||
def remove_crew(self, uuid: Optional[str]) -> None:
|
||||
"""
|
||||
@@ -301,9 +269,9 @@ class DeployCommand:
|
||||
console.print("Removing deployment...", style="bold blue")
|
||||
|
||||
if uuid:
|
||||
response = self.client.delete_by_uuid(uuid)
|
||||
response = self.plus_api_client.delete_crew_by_uuid(uuid)
|
||||
elif self.project_name:
|
||||
response = self.client.delete_by_name(self.project_name)
|
||||
response = self.plus_api_client.delete_crew_by_name(self.project_name)
|
||||
else:
|
||||
self._standard_no_param_error_message()
|
||||
return
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from typing import Optional
|
||||
import requests
|
||||
from os import getenv
|
||||
from crewai.cli.utils import get_crewai_version
|
||||
@@ -9,6 +10,9 @@ class PlusAPI:
|
||||
This class exposes methods for working with the CrewAI+ API.
|
||||
"""
|
||||
|
||||
TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
|
||||
CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
|
||||
|
||||
def __init__(self, api_key: str) -> None:
|
||||
self.api_key = api_key
|
||||
self.headers = {
|
||||
@@ -22,3 +26,67 @@ class PlusAPI:
|
||||
def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
|
||||
url = urljoin(self.base_url, endpoint)
|
||||
return requests.request(method, url, headers=self.headers, **kwargs)
|
||||
|
||||
def get_tool(self, handle: str):
|
||||
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
|
||||
|
||||
def publish_tool(
|
||||
self,
|
||||
handle: str,
|
||||
is_public: bool,
|
||||
version: str,
|
||||
description: Optional[str],
|
||||
encoded_file: str,
|
||||
):
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": is_public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
}
|
||||
return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params)
|
||||
|
||||
def deploy_by_name(self, project_name: str) -> requests.Response:
|
||||
return self._make_request(
|
||||
"POST", f"{self.CREWS_RESOURCE}/by-name/{project_name}/deploy"
|
||||
)
|
||||
|
||||
def deploy_by_uuid(self, uuid: str) -> requests.Response:
|
||||
return self._make_request("POST", f"{self.CREWS_RESOURCE}/{uuid}/deploy")
|
||||
|
||||
def crew_status_by_name(self, project_name: str) -> requests.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/status"
|
||||
)
|
||||
|
||||
def crew_status_by_uuid(self, uuid: str) -> requests.Response:
|
||||
return self._make_request("GET", f"{self.CREWS_RESOURCE}/{uuid}/status")
|
||||
|
||||
def crew_by_name(
|
||||
self, project_name: str, log_type: str = "deployment"
|
||||
) -> requests.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.CREWS_RESOURCE}/by-name/{project_name}/logs/{log_type}"
|
||||
)
|
||||
|
||||
def crew_by_uuid(
|
||||
self, uuid: str, log_type: str = "deployment"
|
||||
) -> requests.Response:
|
||||
return self._make_request(
|
||||
"GET", f"{self.CREWS_RESOURCE}/{uuid}/logs/{log_type}"
|
||||
)
|
||||
|
||||
def delete_crew_by_name(self, project_name: str) -> requests.Response:
|
||||
return self._make_request(
|
||||
"DELETE", f"{self.CREWS_RESOURCE}/by-name/{project_name}"
|
||||
)
|
||||
|
||||
def delete_crew_by_uuid(self, uuid: str) -> requests.Response:
|
||||
return self._make_request("DELETE", f"{self.CREWS_RESOURCE}/{uuid}")
|
||||
|
||||
def list_crews(self) -> requests.Response:
|
||||
return self._make_request("GET", self.CREWS_RESOURCE)
|
||||
|
||||
def create_crew(self, payload) -> requests.Response:
|
||||
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
|
||||
|
||||
@@ -1,26 +0,0 @@
|
||||
from typing import Optional
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
|
||||
|
||||
class ToolsAPI(PlusAPI):
|
||||
RESOURCE = "/crewai_plus/api/v1/tools"
|
||||
|
||||
def get(self, handle: str):
|
||||
return self._make_request("GET", f"{self.RESOURCE}/{handle}")
|
||||
|
||||
def publish(
|
||||
self,
|
||||
handle: str,
|
||||
public: bool,
|
||||
version: str,
|
||||
description: Optional[str],
|
||||
encoded_file: str,
|
||||
):
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
}
|
||||
return self._make_request("POST", f"{self.RESOURCE}", json=params)
|
||||
168
src/crewai/cli/tools/main.py
Normal file
168
src/crewai/cli/tools/main.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import base64
|
||||
import click
|
||||
import os
|
||||
import subprocess
|
||||
import tempfile
|
||||
|
||||
from crewai.cli.command import BaseCommand, PlusAPIMixin
|
||||
from crewai.cli.utils import (
|
||||
get_project_name,
|
||||
get_project_description,
|
||||
get_project_version,
|
||||
)
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
"""
|
||||
A class to handle tool repository related operations for CrewAI projects.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
BaseCommand.__init__(self)
|
||||
PlusAPIMixin.__init__(self, telemetry=self._telemetry)
|
||||
|
||||
def publish(self, is_public: bool):
|
||||
project_name = get_project_name(require=True)
|
||||
assert isinstance(project_name, str)
|
||||
|
||||
project_version = get_project_version(require=True)
|
||||
assert isinstance(project_version, str)
|
||||
|
||||
project_description = get_project_description(require=False)
|
||||
encoded_tarball = None
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_build_dir:
|
||||
subprocess.run(
|
||||
["poetry", "build", "-f", "sdist", "--output", temp_build_dir],
|
||||
check=True,
|
||||
capture_output=False,
|
||||
)
|
||||
|
||||
tarball_filename = next(
|
||||
(f for f in os.listdir(temp_build_dir) if f.endswith(".tar.gz")), None
|
||||
)
|
||||
if not tarball_filename:
|
||||
console.print(
|
||||
"Project build failed. Please ensure that the command `poetry build -f sdist` completes successfully.",
|
||||
style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
tarball_path = os.path.join(temp_build_dir, tarball_filename)
|
||||
with open(tarball_path, "rb") as file:
|
||||
tarball_contents = file.read()
|
||||
|
||||
encoded_tarball = base64.b64encode(tarball_contents).decode("utf-8")
|
||||
|
||||
publish_response = self.plus_api_client.publish_tool(
|
||||
handle=project_name,
|
||||
is_public=is_public,
|
||||
version=project_version,
|
||||
description=project_description,
|
||||
encoded_file=f"data:application/x-gzip;base64,{encoded_tarball}",
|
||||
)
|
||||
if publish_response.status_code == 422:
|
||||
console.print(
|
||||
"[bold red]Failed to publish tool. Please fix the following errors:[/bold red]"
|
||||
)
|
||||
for field, messages in publish_response.json().items():
|
||||
for message in messages:
|
||||
console.print(
|
||||
f"* [bold red]{field.capitalize()}[/bold red] {message}"
|
||||
)
|
||||
|
||||
raise SystemExit
|
||||
elif publish_response.status_code != 200:
|
||||
self._handle_plus_api_error(publish_response.json())
|
||||
console.print(
|
||||
"Failed to publish tool. Please try again later.", style="bold red"
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
published_handle = publish_response.json()["handle"]
|
||||
console.print(
|
||||
f"Succesfully published {published_handle} ({project_version}).\nInstall it in other projects with crewai tool install {published_handle}",
|
||||
style="bold green",
|
||||
)
|
||||
|
||||
def install(self, handle: str):
|
||||
get_response = self.plus_api_client.get_tool(handle)
|
||||
|
||||
if get_response.status_code == 404:
|
||||
console.print(
|
||||
"No tool found with this name. Please ensure the tool was published and you have access to it.",
|
||||
style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
elif get_response.status_code != 200:
|
||||
console.print(
|
||||
"Failed to get tool details. Please try again later.", style="bold red"
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
self._add_repository_to_poetry(get_response.json())
|
||||
self._add_package(get_response.json())
|
||||
|
||||
console.print(f"Succesfully installed {handle}", style="bold green")
|
||||
|
||||
def _add_repository_to_poetry(self, tool_details):
|
||||
repository_handle = f"crewai-{tool_details['repository']['handle']}"
|
||||
repository_url = tool_details["repository"]["url"]
|
||||
repository_credentials = tool_details["repository"]["credentials"]
|
||||
|
||||
add_repository_command = [
|
||||
"poetry",
|
||||
"source",
|
||||
"add",
|
||||
"--priority=explicit",
|
||||
repository_handle,
|
||||
repository_url,
|
||||
]
|
||||
add_repository_result = subprocess.run(
|
||||
add_repository_command, text=True, check=True
|
||||
)
|
||||
|
||||
if add_repository_result.stderr:
|
||||
click.echo(add_repository_result.stderr, err=True)
|
||||
raise SystemExit
|
||||
|
||||
add_repository_credentials_command = [
|
||||
"poetry",
|
||||
"config",
|
||||
f"http-basic.{repository_handle}",
|
||||
repository_credentials,
|
||||
'""',
|
||||
]
|
||||
add_repository_credentials_result = subprocess.run(
|
||||
add_repository_credentials_command,
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
|
||||
if add_repository_credentials_result.stderr:
|
||||
click.echo(add_repository_credentials_result.stderr, err=True)
|
||||
raise SystemExit
|
||||
|
||||
def _add_package(self, tool_details):
|
||||
tool_handle = tool_details["handle"]
|
||||
repository_handle = tool_details["repository"]["handle"]
|
||||
pypi_index_handle = f"crewai-{repository_handle}"
|
||||
|
||||
add_package_command = [
|
||||
"poetry",
|
||||
"add",
|
||||
"--source",
|
||||
pypi_index_handle,
|
||||
tool_handle,
|
||||
]
|
||||
add_package_result = subprocess.run(
|
||||
add_package_command, capture_output=False, text=True, check=True
|
||||
)
|
||||
|
||||
if add_package_result.stderr:
|
||||
click.echo(add_package_result.stderr, err=True)
|
||||
raise SystemExit
|
||||
@@ -3,8 +3,10 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
from rich.console import Console
|
||||
from crewai.cli.authentication.utils import TokenManager
|
||||
from functools import reduce
|
||||
from rich.console import Console
|
||||
from typing import Any, Dict, List
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
import tomllib
|
||||
@@ -88,21 +90,51 @@ def get_git_remote_url() -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def get_project_name(pyproject_path: str = "pyproject.toml") -> str | None:
|
||||
def get_project_name(
|
||||
pyproject_path: str = "pyproject.toml", require: bool = False
|
||||
) -> str | None:
|
||||
"""Get the project name from the pyproject.toml file."""
|
||||
return _get_project_attribute(
|
||||
pyproject_path, ["tool", "poetry", "name"], require=require
|
||||
)
|
||||
|
||||
|
||||
def get_project_version(
|
||||
pyproject_path: str = "pyproject.toml", require: bool = False
|
||||
) -> str | None:
|
||||
"""Get the project version from the pyproject.toml file."""
|
||||
return _get_project_attribute(
|
||||
pyproject_path, ["tool", "poetry", "version"], require=require
|
||||
)
|
||||
|
||||
|
||||
def get_project_description(
|
||||
pyproject_path: str = "pyproject.toml", require: bool = False
|
||||
) -> str | None:
|
||||
"""Get the project description from the pyproject.toml file."""
|
||||
return _get_project_attribute(
|
||||
pyproject_path, ["tool", "poetry", "description"], require=require
|
||||
)
|
||||
|
||||
|
||||
def _get_project_attribute(
|
||||
pyproject_path: str, keys: List[str], require: bool
|
||||
) -> Any | None:
|
||||
"""Get an attribute from the pyproject.toml file."""
|
||||
attribute = None
|
||||
|
||||
try:
|
||||
# Read the pyproject.toml file
|
||||
with open(pyproject_path, "r") as f:
|
||||
pyproject_content = parse_toml(f.read())
|
||||
|
||||
# Extract the project name
|
||||
project_name = pyproject_content["tool"]["poetry"]["name"]
|
||||
|
||||
if "crewai" not in pyproject_content["tool"]["poetry"]["dependencies"]:
|
||||
dependencies = (
|
||||
_get_nested_value(pyproject_content, ["tool", "poetry", "dependencies"])
|
||||
or {}
|
||||
)
|
||||
if "crewai" not in dependencies:
|
||||
raise Exception("crewai is not in the dependencies.")
|
||||
|
||||
return project_name
|
||||
|
||||
attribute = _get_nested_value(pyproject_content, keys)
|
||||
except FileNotFoundError:
|
||||
print(f"Error: {pyproject_path} not found.")
|
||||
except KeyError:
|
||||
@@ -116,7 +148,18 @@ def get_project_name(pyproject_path: str = "pyproject.toml") -> str | None:
|
||||
except Exception as e:
|
||||
print(f"Error reading the pyproject.toml file: {e}")
|
||||
|
||||
return None
|
||||
if require and not attribute:
|
||||
console.print(
|
||||
f"Unable to read '{'.'.join(keys)}' in the pyproject.toml file. Please verify that the file exists and contains the specified attribute.",
|
||||
style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
return attribute
|
||||
|
||||
|
||||
def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
|
||||
return reduce(dict.__getitem__, keys, data)
|
||||
|
||||
|
||||
def get_crewai_version(poetry_lock_path: str = "poetry.lock") -> str:
|
||||
|
||||
@@ -8,34 +8,34 @@ from crewai.cli.utils import parse_toml
|
||||
|
||||
|
||||
class TestDeployCommand(unittest.TestCase):
|
||||
@patch("crewai.cli.deploy.main.get_auth_token")
|
||||
@patch("crewai.cli.command.get_auth_token")
|
||||
@patch("crewai.cli.deploy.main.get_project_name")
|
||||
@patch("crewai.cli.deploy.main.CrewAPI")
|
||||
def setUp(self, mock_crew_api, mock_get_project_name, mock_get_auth_token):
|
||||
@patch("crewai.cli.command.PlusAPI")
|
||||
def setUp(self, mock_plus_api, mock_get_project_name, mock_get_auth_token):
|
||||
self.mock_get_auth_token = mock_get_auth_token
|
||||
self.mock_get_project_name = mock_get_project_name
|
||||
self.mock_crew_api = mock_crew_api
|
||||
self.mock_plus_api = mock_plus_api
|
||||
|
||||
self.mock_get_auth_token.return_value = "test_token"
|
||||
self.mock_get_project_name.return_value = "test_project"
|
||||
|
||||
self.deploy_command = DeployCommand()
|
||||
self.mock_client = self.deploy_command.client
|
||||
self.mock_client = self.deploy_command.plus_api_client
|
||||
|
||||
def test_init_success(self):
|
||||
self.assertEqual(self.deploy_command.project_name, "test_project")
|
||||
self.mock_crew_api.assert_called_once_with(api_key="test_token")
|
||||
self.mock_plus_api.assert_called_once_with(api_key="test_token")
|
||||
|
||||
@patch("crewai.cli.deploy.main.get_auth_token")
|
||||
@patch("crewai.cli.command.get_auth_token")
|
||||
def test_init_failure(self, mock_get_auth_token):
|
||||
mock_get_auth_token.side_effect = Exception("Auth failed")
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
DeployCommand()
|
||||
|
||||
def test_handle_error(self):
|
||||
def test_handle_plus_api_error(self):
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command._handle_error(
|
||||
self.deploy_command._handle_plus_api_error(
|
||||
{"error": "Test error", "message": "Test message"}
|
||||
)
|
||||
self.assertIn("Error: Test error", fake_out.getvalue())
|
||||
@@ -122,7 +122,7 @@ class TestDeployCommand(unittest.TestCase):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"name": "TestCrew", "status": "active"}
|
||||
self.mock_client.status_by_name.return_value = mock_response
|
||||
self.mock_client.crew_status_by_name.return_value = mock_response
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command.get_crew_status()
|
||||
@@ -136,7 +136,7 @@ class TestDeployCommand(unittest.TestCase):
|
||||
{"timestamp": "2023-01-01", "level": "INFO", "message": "Log1"},
|
||||
{"timestamp": "2023-01-02", "level": "ERROR", "message": "Log2"},
|
||||
]
|
||||
self.mock_client.logs_by_name.return_value = mock_response
|
||||
self.mock_client.crew_by_name.return_value = mock_response
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command.get_crew_logs(None)
|
||||
@@ -146,7 +146,7 @@ class TestDeployCommand(unittest.TestCase):
|
||||
def test_remove_crew(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 204
|
||||
self.mock_client.delete_by_name.return_value = mock_response
|
||||
self.mock_client.delete_crew_by_name.return_value = mock_response
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command.remove_crew(None)
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import os
|
||||
import unittest
|
||||
from os import environ
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai.cli.deploy.api import CrewAPI
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
|
||||
|
||||
class TestCrewAPI(unittest.TestCase):
|
||||
class TestPlusAPI(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.api_key = "test_api_key"
|
||||
self.api = CrewAPI(self.api_key)
|
||||
self.api = PlusAPI(self.api_key)
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(self.api.api_key, self.api_key)
|
||||
@@ -22,6 +21,70 @@ class TestCrewAPI(unittest.TestCase):
|
||||
},
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_get_tool(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
|
||||
response = self.api.get_tool("test_tool_handle")
|
||||
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_publish_tool(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
handle = "test_tool_handle"
|
||||
public = True
|
||||
version = "1.0.0"
|
||||
description = "Test tool description"
|
||||
encoded_file = "encoded_test_file"
|
||||
|
||||
response = self.api.publish_tool(
|
||||
handle, public, version, description, encoded_file
|
||||
)
|
||||
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_publish_tool_without_description(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
handle = "test_tool_handle"
|
||||
public = False
|
||||
version = "2.0.0"
|
||||
description = None
|
||||
encoded_file = "encoded_test_file"
|
||||
|
||||
response = self.api.publish_tool(
|
||||
handle, public, version, description, encoded_file
|
||||
)
|
||||
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.requests.request")
|
||||
def test_make_request(self, mock_request):
|
||||
mock_response = MagicMock()
|
||||
@@ -49,53 +112,53 @@ class TestCrewAPI(unittest.TestCase):
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_status_by_name(self, mock_make_request):
|
||||
self.api.status_by_name("test_project")
|
||||
def test_crew_status_by_name(self, mock_make_request):
|
||||
self.api.crew_status_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/status"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_status_by_uuid(self, mock_make_request):
|
||||
self.api.status_by_uuid("test_uuid")
|
||||
def test_crew_status_by_uuid(self, mock_make_request):
|
||||
self.api.crew_status_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/test_uuid/status"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_logs_by_name(self, mock_make_request):
|
||||
self.api.logs_by_name("test_project")
|
||||
def test_crew_by_name(self, mock_make_request):
|
||||
self.api.crew_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/deployment"
|
||||
)
|
||||
|
||||
self.api.logs_by_name("test_project", "custom_log")
|
||||
self.api.crew_by_name("test_project", "custom_log")
|
||||
mock_make_request.assert_called_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_logs_by_uuid(self, mock_make_request):
|
||||
self.api.logs_by_uuid("test_uuid")
|
||||
def test_crew_by_uuid(self, mock_make_request):
|
||||
self.api.crew_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/deployment"
|
||||
)
|
||||
|
||||
self.api.logs_by_uuid("test_uuid", "custom_log")
|
||||
self.api.crew_by_uuid("test_uuid", "custom_log")
|
||||
mock_make_request.assert_called_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_delete_by_name(self, mock_make_request):
|
||||
self.api.delete_by_name("test_project")
|
||||
def test_delete_crew_by_name(self, mock_make_request):
|
||||
self.api.delete_crew_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"DELETE", "/crewai_plus/api/v1/crews/by-name/test_project"
|
||||
)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_delete_by_uuid(self, mock_make_request):
|
||||
self.api.delete_by_uuid("test_uuid")
|
||||
def test_delete_crew_by_uuid(self, mock_make_request):
|
||||
self.api.delete_crew_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"DELETE", "/crewai_plus/api/v1/crews/test_uuid"
|
||||
)
|
||||
@@ -105,7 +168,7 @@ class TestCrewAPI(unittest.TestCase):
|
||||
self.api.list_crews()
|
||||
mock_make_request.assert_called_once_with("GET", "/crewai_plus/api/v1/crews")
|
||||
|
||||
@patch("crewai.cli.deploy.api.CrewAPI._make_request")
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_create_crew(self, mock_make_request):
|
||||
payload = {"name": "test_crew"}
|
||||
self.api.create_crew(payload)
|
||||
@@ -113,9 +176,9 @@ class TestCrewAPI(unittest.TestCase):
|
||||
"POST", "/crewai_plus/api/v1/crews", json=payload
|
||||
)
|
||||
|
||||
@patch.dict(environ, {"CREWAI_BASE_URL": "https://custom-url.com/api"})
|
||||
@patch.dict(os.environ, {"CREWAI_BASE_URL": "https://custom-url.com/api"})
|
||||
def test_custom_base_url(self):
|
||||
custom_api = CrewAPI("test_key")
|
||||
custom_api = PlusAPI("test_key")
|
||||
self.assertEqual(
|
||||
custom_api.base_url,
|
||||
"https://custom-url.com/api",
|
||||
@@ -1,69 +0,0 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from crewai.cli.tools.api import ToolsAPI
|
||||
|
||||
|
||||
class TestToolsAPI(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.api_key = "test_api_key"
|
||||
self.api = ToolsAPI(self.api_key)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_get_tool(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
|
||||
response = self.api.get("test_tool_handle")
|
||||
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_publish_tool(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
handle = "test_tool_handle"
|
||||
public = True
|
||||
version = "1.0.0"
|
||||
description = "Test tool description"
|
||||
encoded_file = "encoded_test_file"
|
||||
|
||||
response = self.api.publish(handle, public, version, description, encoded_file)
|
||||
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_publish_tool_without_description(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
handle = "test_tool_handle"
|
||||
public = False
|
||||
version = "2.0.0"
|
||||
description = None
|
||||
encoded_file = "encoded_test_file"
|
||||
|
||||
response = self.api.publish(handle, public, version, description, encoded_file)
|
||||
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
229
tests/cli/tools/test_main.py
Normal file
229
tests/cli/tools/test_main.py
Normal file
@@ -0,0 +1,229 @@
|
||||
import unittest
|
||||
import unittest.mock
|
||||
from crewai.cli.tools.main import ToolCommand
|
||||
from io import StringIO
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
|
||||
class TestToolCommand(unittest.TestCase):
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_success(self, mock_get, mock_subprocess_run):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 200
|
||||
mock_get_response.json.return_value = {
|
||||
"handle": "sample-tool",
|
||||
"repository": {
|
||||
"handle": "sample-repo",
|
||||
"url": "https://example.com/repo",
|
||||
"credentials": "my_very_secret",
|
||||
},
|
||||
}
|
||||
mock_get.return_value = mock_get_response
|
||||
mock_subprocess_run.return_value = MagicMock(stderr=None)
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
tool_command.install("sample-tool")
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_get.assert_called_once_with("sample-tool")
|
||||
mock_subprocess_run.assert_any_call(
|
||||
[
|
||||
"poetry",
|
||||
"source",
|
||||
"add",
|
||||
"--priority=explicit",
|
||||
"crewai-sample-repo",
|
||||
"https://example.com/repo",
|
||||
],
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
mock_subprocess_run.assert_any_call(
|
||||
[
|
||||
"poetry",
|
||||
"config",
|
||||
"http-basic.crewai-sample-repo",
|
||||
"my_very_secret",
|
||||
'""',
|
||||
],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
mock_subprocess_run.assert_any_call(
|
||||
["poetry", "add", "--source", "crewai-sample-repo", "sample-tool"],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
|
||||
self.assertIn("Succesfully installed sample-tool", output)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_tool_not_found(self, mock_get):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 404
|
||||
mock_get.return_value = mock_get_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with self.assertRaises(SystemExit):
|
||||
tool_command.install("non-existent-tool")
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_get.assert_called_once_with("non-existent-tool")
|
||||
self.assertIn("No tool found with this name", output)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_api_error(self, mock_get):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 500
|
||||
mock_get.return_value = mock_get_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with self.assertRaises(SystemExit):
|
||||
tool_command.install("error-tool")
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_get.assert_called_once_with("error-tool")
|
||||
self.assertIn("Failed to get tool details", output)
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.get_project_description", return_value="A sample tool"
|
||||
)
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
|
||||
def test_publish_success(
|
||||
self,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 200
|
||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
tool_command.publish(is_public=True)
|
||||
|
||||
mock_get_project_name.assert_called_once_with(require=True)
|
||||
mock_get_project_version.assert_called_once_with(require=True)
|
||||
mock_get_project_description.assert_called_once_with(require=False)
|
||||
mock_subprocess_run.assert_called_once_with(
|
||||
["poetry", "build", "-f", "sdist", "--output", unittest.mock.ANY],
|
||||
check=True,
|
||||
capture_output=False,
|
||||
)
|
||||
mock_open.assert_called_once_with(unittest.mock.ANY, "rb")
|
||||
mock_publish.assert_called_once_with(
|
||||
handle="sample-tool",
|
||||
is_public=True,
|
||||
version="1.0.0",
|
||||
description="A sample tool",
|
||||
encoded_file=unittest.mock.ANY,
|
||||
)
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.get_project_description", return_value="A sample tool"
|
||||
)
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
|
||||
def test_publish_failure(
|
||||
self,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 422
|
||||
mock_publish_response.json.return_value = {"name": ["is already taken"]}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with self.assertRaises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
self.assertIn("Failed to publish tool", output)
|
||||
self.assertIn("Name is already taken", output)
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.get_project_description", return_value="A sample tool"
|
||||
)
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch(
|
||||
"crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
|
||||
def test_publish_api_error(
|
||||
self,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 500
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command = ToolCommand()
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with self.assertRaises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
output = fake_out.getvalue()
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
self.assertIn("Failed to publish tool", output)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user