Change Tool Repository authentication scope (#1378)

This commit adds a new command for adding custom PyPI indexes
credentials to the project. This was changed because credentials are now
user-scoped instead of organization.
This commit is contained in:
Vini Brasil
2024-10-01 18:44:08 -03:00
committed by GitHub
parent 94f148e524
commit 2c74efc8f2
9 changed files with 220 additions and 150 deletions

View File

@@ -264,6 +264,7 @@ def deploy_remove(uuid: Optional[str]):
@click.argument("handle")
def tool_install(handle: str):
tool_cmd = ToolCommand()
tool_cmd.login()
tool_cmd.install(handle)
@@ -272,6 +273,7 @@ def tool_install(handle: str):
@click.option("--private", "is_public", flag_value=False)
def tool_publish(is_public: bool):
tool_cmd = ToolCommand()
tool_cmd.login()
tool_cmd.publish(is_public)

View File

@@ -1,4 +1,5 @@
from typing import Dict, Any
import requests
from requests.exceptions import JSONDecodeError
from rich.console import Console
from crewai.cli.plus_api import PlusAPI
from crewai.cli.utils import get_auth_token
@@ -27,14 +28,44 @@ class PlusAPIMixin:
console.print("Run 'crewai signup' to sign up/login.", style="bold green")
raise SystemExit
def _handle_plus_api_error(self, json_response: Dict[str, Any]) -> None:
def _validate_response(self, response: requests.Response) -> None:
"""
Handle and display error messages from API responses.
Args:
json_response (Dict[str, Any]): The JSON response containing error information.
response (requests.Response): The response from the Plus API
"""
error = json_response.get("error", "Unknown error")
message = json_response.get("message", "No message provided")
console.print(f"Error: {error}", style="bold red")
console.print(f"Message: {message}", style="bold red")
try:
json_response = response.json()
except (JSONDecodeError, ValueError):
console.print(
"Failed to parse response from Enterprise API failed. Details:",
style="bold red",
)
console.print(f"Status Code: {response.status_code}")
console.print(f"Response:\n{response.content}")
raise SystemExit
if response.status_code == 422:
console.print(
"Failed to complete operation. Please fix the following errors:",
style="bold red",
)
for field, messages in json_response.items():
for message in messages:
console.print(
f"* [bold red]{field.capitalize()}[/bold red] {message}"
)
raise SystemExit
if not response.ok:
console.print(
"Request to Enterprise API failed. Details:", style="bold red"
)
details = (
json_response.get("error")
or json_response.get("message")
or response.content
)
console.print(f"{details}")
raise SystemExit

View File

@@ -79,11 +79,8 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
self._standard_no_param_error_message()
return
json_response = response.json()
if response.status_code == 200:
self._display_deployment_info(json_response)
else:
self._handle_plus_api_error(json_response)
self._validate_response(response)
self._display_deployment_info(response.json())
def create_crew(self, confirm: bool = False) -> None:
"""
@@ -106,12 +103,10 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
self._confirm_input(env_vars, remote_repo_url, confirm)
payload = self._create_payload(env_vars, remote_repo_url)
response = self.plus_api_client.create_crew(payload)
if response.status_code == 201:
self._display_creation_success(response.json())
else:
self._handle_plus_api_error(response.json())
self._validate_response(response)
self._display_creation_success(response.json())
def _confirm_input(
self, env_vars: Dict[str, str], remote_repo_url: str, confirm: bool
@@ -218,11 +213,8 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
self._standard_no_param_error_message()
return
json_response = response.json()
if response.status_code == 200:
self._display_crew_status(json_response)
else:
self._handle_plus_api_error(json_response)
self._validate_response(response)
self._display_crew_status(response.json())
def _display_crew_status(self, status_data: Dict[str, str]) -> None:
"""
@@ -253,10 +245,8 @@ class DeployCommand(BaseCommand, PlusAPIMixin):
self._standard_no_param_error_message()
return
if response.status_code == 200:
self._display_logs(response.json())
else:
self._handle_plus_api_error(response.json())
self._validate_response(response)
self._display_logs(response.json())
def remove_crew(self, uuid: Optional[str]) -> None:
"""

View File

@@ -27,6 +27,9 @@ class PlusAPI:
url = urljoin(self.base_url, endpoint)
return requests.request(method, url, headers=self.headers, **kwargs)
def login_to_tool_repository(self):
return self._make_request("POST", f"{self.TOOLS_RESOURCE}/login")
def get_tool(self, handle: str):
return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")

View File

@@ -64,23 +64,8 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
description=project_description,
encoded_file=f"data:application/x-gzip;base64,{encoded_tarball}",
)
if publish_response.status_code == 422:
console.print(
"[bold red]Failed to publish tool. Please fix the following errors:[/bold red]"
)
for field, messages in publish_response.json().items():
for message in messages:
console.print(
f"* [bold red]{field.capitalize()}[/bold red] {message}"
)
raise SystemExit
elif publish_response.status_code != 200:
self._handle_plus_api_error(publish_response.json())
console.print(
"Failed to publish tool. Please try again later.", style="bold red"
)
raise SystemExit
self._validate_response(publish_response)
published_handle = publish_response.json()["handle"]
console.print(
@@ -103,15 +88,32 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
)
raise SystemExit
self._add_repository_to_poetry(get_response.json())
self._add_package(get_response.json())
console.print(f"Succesfully installed {handle}", style="bold green")
def _add_repository_to_poetry(self, tool_details):
repository_handle = f"crewai-{tool_details['repository']['handle']}"
repository_url = tool_details["repository"]["url"]
repository_credentials = tool_details["repository"]["credentials"]
def login(self):
login_response = self.plus_api_client.login_to_tool_repository()
if login_response.status_code != 200:
console.print(
"Failed to authenticate to the tool repository. Make sure you have the access to tools.",
style="bold red",
)
raise SystemExit
login_response_json = login_response.json()
for repository in login_response_json["repositories"]:
self._add_repository_to_poetry(
repository, login_response_json["credential"]
)
console.print(
"Succesfully authenticated to the tool repository.", style="bold green"
)
def _add_repository_to_poetry(self, repository, credentials):
repository_handle = f"crewai-{repository['handle']}"
add_repository_command = [
"poetry",
@@ -119,7 +121,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
"add",
"--priority=explicit",
repository_handle,
repository_url,
repository["url"],
]
add_repository_result = subprocess.run(
add_repository_command, text=True, check=True
@@ -133,8 +135,8 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
"poetry",
"config",
f"http-basic.{repository_handle}",
repository_credentials,
'""',
credentials["username"],
credentials["password"],
]
add_repository_credentials_result = subprocess.run(
add_repository_credentials_command,

View File

@@ -2,6 +2,7 @@ import click
import re
import subprocess
import sys
import importlib.metadata
from crewai.cli.authentication.utils import TokenManager
from functools import reduce
@@ -162,29 +163,9 @@ def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any:
return reduce(dict.__getitem__, keys, data)
def get_crewai_version(poetry_lock_path: str = "poetry.lock") -> str:
"""Get the version number of crewai from the poetry.lock file."""
try:
with open(poetry_lock_path, "r") as f:
lock_content = f.read()
match = re.search(
r'\[\[package\]\]\s*name\s*=\s*"crewai"\s*version\s*=\s*"([^"]+)"',
lock_content,
re.DOTALL,
)
if match:
return match.group(1)
else:
print("crewai package not found in poetry.lock")
return "no-version-found"
except FileNotFoundError:
print(f"Error: {poetry_lock_path} not found.")
except Exception as e:
print(f"Error reading the poetry.lock file: {e}")
return "no-version-found"
def get_crewai_version() -> str:
"""Get the version number of CrewAI running the CLI"""
return importlib.metadata.version("crewai")
def fetch_and_json_env_file(env_file_path: str = ".env") -> dict:

View File

@@ -1,7 +1,11 @@
import unittest
from io import StringIO
from unittest.mock import MagicMock, patch
import pytest
import requests
import sys
import unittest
from io import StringIO
from requests.exceptions import JSONDecodeError
from unittest.mock import MagicMock, Mock, patch
from crewai.cli.deploy.main import DeployCommand
from crewai.cli.utils import parse_toml
@@ -33,13 +37,65 @@ class TestDeployCommand(unittest.TestCase):
with self.assertRaises(SystemExit):
DeployCommand()
def test_handle_plus_api_error(self):
def test_validate_response_successful_response(self):
mock_response = Mock(spec=requests.Response)
mock_response.json.return_value = {"message": "Success"}
mock_response.status_code = 200
mock_response.ok = True
with patch("sys.stdout", new=StringIO()) as fake_out:
self.deploy_command._handle_plus_api_error(
{"error": "Test error", "message": "Test message"}
self.deploy_command._validate_response(mock_response)
assert fake_out.getvalue() == ""
def test_validate_response_json_decode_error(self):
mock_response = Mock(spec=requests.Response)
mock_response.json.side_effect = JSONDecodeError("Decode error", "", 0)
mock_response.status_code = 500
mock_response.content = b"Invalid JSON"
with patch("sys.stdout", new=StringIO()) as fake_out:
with pytest.raises(SystemExit):
self.deploy_command._validate_response(mock_response)
output = fake_out.getvalue()
assert (
"Failed to parse response from Enterprise API failed. Details:"
in output
)
self.assertIn("Error: Test error", fake_out.getvalue())
self.assertIn("Message: Test message", fake_out.getvalue())
assert "Status Code: 500" in output
assert "Response:\nb'Invalid JSON'" in output
def test_validate_response_422_error(self):
mock_response = Mock(spec=requests.Response)
mock_response.json.return_value = {
"field1": ["Error message 1"],
"field2": ["Error message 2"],
}
mock_response.status_code = 422
mock_response.ok = False
with patch("sys.stdout", new=StringIO()) as fake_out:
with pytest.raises(SystemExit):
self.deploy_command._validate_response(mock_response)
output = fake_out.getvalue()
assert (
"Failed to complete operation. Please fix the following errors:"
in output
)
assert "Field1 Error message 1" in output
assert "Field2 Error message 2" in output
def test_validate_response_other_error(self):
mock_response = Mock(spec=requests.Response)
mock_response.json.return_value = {"error": "Something went wrong"}
mock_response.status_code = 500
mock_response.ok = False
with patch("sys.stdout", new=StringIO()) as fake_out:
with pytest.raises(SystemExit):
self.deploy_command._validate_response(mock_response)
output = fake_out.getvalue()
assert "Request to Enterprise API failed. Details:" in output
assert "Details:\nSomething went wrong" in output
def test_standard_no_param_error_message(self):
with patch("sys.stdout", new=StringIO()) as fake_out:
@@ -207,30 +263,7 @@ class TestDeployCommand(unittest.TestCase):
project_name = get_project_name()
self.assertEqual(project_name, "test_project")
@patch(
"builtins.open",
new_callable=unittest.mock.mock_open,
read_data="""
[[package]]
name = "crewai"
version = "0.51.1"
description = "Some description"
category = "main"
optional = false
python-versions = ">=3.10,<4.0"
""",
)
def test_get_crewai_version(self, mock_open):
def test_get_crewai_version(self):
from crewai.cli.utils import get_crewai_version
version = get_crewai_version()
self.assertEqual(version, "0.51.1")
@patch("builtins.open", side_effect=FileNotFoundError)
def test_get_crewai_version_file_not_found(self, mock_open):
from crewai.cli.utils import get_crewai_version
with patch("sys.stdout", new=StringIO()) as fake_out:
version = get_crewai_version()
self.assertEqual(version, "no-version-found")
self.assertIn("Error: poetry.lock not found.", fake_out.getvalue())
assert isinstance(get_crewai_version(), str)

View File

@@ -11,15 +11,22 @@ class TestPlusAPI(unittest.TestCase):
def test_init(self):
self.assertEqual(self.api.api_key, self.api_key)
self.assertEqual(
self.api.headers,
{
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
"User-Agent": "CrewAI-CLI/no-version-found",
"X-Crewai-Version": "no-version-found",
},
self.assertEqual(self.api.headers["Authorization"], f"Bearer {self.api_key}")
self.assertEqual(self.api.headers["Content-Type"], "application/json")
self.assertTrue("CrewAI-CLI/" in self.api.headers["User-Agent"])
self.assertTrue(self.api.headers["X-Crewai-Version"])
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_login_to_tool_repository(self, mock_make_request):
mock_response = MagicMock()
mock_make_request.return_value = mock_response
response = self.api.login_to_tool_repository()
mock_make_request.assert_called_once_with(
"POST", "/crewai_plus/api/v1/tools/login"
)
self.assertEqual(response, mock_response)
@patch("crewai.cli.plus_api.PlusAPI._make_request")
def test_get_tool(self, mock_make_request):

View File

@@ -13,11 +13,7 @@ class TestToolCommand(unittest.TestCase):
mock_get_response.status_code = 200
mock_get_response.json.return_value = {
"handle": "sample-tool",
"repository": {
"handle": "sample-repo",
"url": "https://example.com/repo",
"credentials": "my_very_secret",
},
"repository": {"handle": "sample-repo", "url": "https://example.com/repo"},
}
mock_get.return_value = mock_get_response
mock_subprocess_run.return_value = MagicMock(stderr=None)
@@ -29,30 +25,6 @@ class TestToolCommand(unittest.TestCase):
output = fake_out.getvalue()
mock_get.assert_called_once_with("sample-tool")
mock_subprocess_run.assert_any_call(
[
"poetry",
"source",
"add",
"--priority=explicit",
"crewai-sample-repo",
"https://example.com/repo",
],
text=True,
check=True,
)
mock_subprocess_run.assert_any_call(
[
"poetry",
"config",
"http-basic.crewai-sample-repo",
"my_very_secret",
'""',
],
capture_output=False,
text=True,
check=True,
)
mock_subprocess_run.assert_any_call(
["poetry", "add", "--source", "crewai-sample-repo", "sample-tool"],
capture_output=False,
@@ -182,7 +154,7 @@ class TestToolCommand(unittest.TestCase):
output = fake_out.getvalue()
mock_publish.assert_called_once()
self.assertIn("Failed to publish tool", output)
self.assertIn("Failed to complete operation", output)
self.assertIn("Name is already taken", output)
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
@@ -210,9 +182,11 @@ class TestToolCommand(unittest.TestCase):
mock_get_project_version,
mock_get_project_name,
):
mock_publish_response = MagicMock()
mock_publish_response.status_code = 500
mock_publish.return_value = mock_publish_response
mock_response = MagicMock()
mock_response.status_code = 500
mock_response.json.return_value = {"error": "Internal Server Error"}
mock_response.ok = False
mock_publish.return_value = mock_response
tool_command = ToolCommand()
@@ -222,8 +196,55 @@ class TestToolCommand(unittest.TestCase):
output = fake_out.getvalue()
mock_publish.assert_called_once()
self.assertIn("Failed to publish tool", output)
self.assertIn("Request to Enterprise API failed", output)
@patch("crewai.cli.plus_api.PlusAPI.login_to_tool_repository")
@patch("crewai.cli.tools.main.subprocess.run")
def test_login_success(self, mock_subprocess_run, mock_login):
mock_login_response = MagicMock()
mock_login_response.status_code = 200
mock_login_response.json.return_value = {
"repositories": [
{
"handle": "tools",
"url": "https://example.com/repo",
}
],
"credential": {"username": "user", "password": "pass"},
}
mock_login.return_value = mock_login_response
if __name__ == "__main__":
unittest.main()
mock_subprocess_run.return_value = MagicMock(stderr=None)
tool_command = ToolCommand()
with patch("sys.stdout", new=StringIO()) as fake_out:
tool_command.login()
output = fake_out.getvalue()
mock_login.assert_called_once()
mock_subprocess_run.assert_any_call(
[
"poetry",
"source",
"add",
"--priority=explicit",
"crewai-tools",
"https://example.com/repo",
],
text=True,
check=True,
)
mock_subprocess_run.assert_any_call(
[
"poetry",
"config",
"http-basic.crewai-tools",
"user",
"pass",
],
capture_output=False,
text=True,
check=True,
)
self.assertIn("Succesfully authenticated to the tool repository", output)