diff --git a/src/crewai/cli/authentication/main.py b/src/crewai/cli/authentication/main.py index 543f06844..c64045801 100644 --- a/src/crewai/cli/authentication/main.py +++ b/src/crewai/cli/authentication/main.py @@ -7,6 +7,7 @@ from rich.console import Console from .constants import AUTH0_AUDIENCE, AUTH0_CLIENT_ID, AUTH0_DOMAIN from .utils import TokenManager, validate_token +from crewai.cli.tools.main import ToolCommand console = Console() @@ -63,7 +64,22 @@ class AuthenticationCommand: validate_token(token_data["id_token"]) expires_in = 360000 # Token expiration time in seconds self.token_manager.save_tokens(token_data["access_token"], expires_in) - console.print("\nWelcome to CrewAI+ !!", style="green") + + try: + ToolCommand().login() + except Exception: + console.print( + "\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.", + style="yellow", + ) + console.print( + "Other features will work normally, but you may experience limitations " + "with downloading and publishing tools." + "\nRun [bold]crewai login[/bold] to try logging in again.\n", + style="yellow", + ) + + console.print("\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n") return if token_data["error"] not in ("authorization_pending", "slow_down"): diff --git a/src/crewai/cli/authentication/token.py b/src/crewai/cli/authentication/token.py new file mode 100644 index 000000000..889fdc2eb --- /dev/null +++ b/src/crewai/cli/authentication/token.py @@ -0,0 +1,10 @@ +from .utils import TokenManager + +def get_auth_token() -> str: + """Get the authentication token.""" + access_token = TokenManager().get_token() + if not access_token: + raise Exception() + return access_token + + diff --git a/src/crewai/cli/command.py b/src/crewai/cli/command.py index f05fe237f..f2af92bf5 100644 --- a/src/crewai/cli/command.py +++ b/src/crewai/cli/command.py @@ -2,7 +2,7 @@ import requests from requests.exceptions import JSONDecodeError from rich.console import Console from crewai.cli.plus_api import PlusAPI -from crewai.cli.utils import get_auth_token +from crewai.cli.authentication.token import get_auth_token from crewai.telemetry.telemetry import Telemetry console = Console() diff --git a/src/crewai/cli/plus_api.py b/src/crewai/cli/plus_api.py index 04f6fb8ff..2fce0d6d8 100644 --- a/src/crewai/cli/plus_api.py +++ b/src/crewai/cli/plus_api.py @@ -1,7 +1,7 @@ from typing import Optional import requests from os import getenv -from crewai.cli.utils import get_crewai_version +from crewai.cli.version import get_crewai_version from urllib.parse import urljoin diff --git a/src/crewai/cli/run_crew.py b/src/crewai/cli/run_crew.py index 5450cf32b..95b560109 100644 --- a/src/crewai/cli/run_crew.py +++ b/src/crewai/cli/run_crew.py @@ -3,7 +3,8 @@ import subprocess import click from packaging import version -from crewai.cli.utils import get_crewai_version, read_toml +from crewai.cli.utils import read_toml +from crewai.cli.version import get_crewai_version def run_crew() -> None: diff --git a/src/crewai/cli/utils.py b/src/crewai/cli/utils.py index 25da9e31a..2daba4111 100644 --- a/src/crewai/cli/utils.py +++ b/src/crewai/cli/utils.py @@ -1,4 +1,3 @@ -import importlib.metadata import os import shutil import sys @@ -9,7 +8,6 @@ import click import tomli from rich.console import Console -from crewai.cli.authentication.utils import TokenManager from crewai.cli.constants import ENV_VARS if sys.version_info >= (3, 11): @@ -137,11 +135,6 @@ def _get_nested_value(data: Dict[str, Any], keys: List[str]) -> Any: return reduce(dict.__getitem__, keys, data) -def get_crewai_version() -> str: - """Get the version number of CrewAI running the CLI""" - return importlib.metadata.version("crewai") - - def fetch_and_json_env_file(env_file_path: str = ".env") -> dict: """Fetch the environment variables from a .env file and return them as a dictionary.""" try: @@ -166,14 +159,6 @@ def fetch_and_json_env_file(env_file_path: str = ".env") -> dict: return {} -def get_auth_token() -> str: - """Get the authentication token.""" - access_token = TokenManager().get_token() - if not access_token: - raise Exception() - return access_token - - def tree_copy(source, destination): """Copies the entire directory structure from the source to the destination.""" for item in os.listdir(source): diff --git a/src/crewai/cli/version.py b/src/crewai/cli/version.py new file mode 100644 index 000000000..543be9c32 --- /dev/null +++ b/src/crewai/cli/version.py @@ -0,0 +1,6 @@ +import importlib.metadata + +def get_crewai_version() -> str: + """Get the version number of CrewAI running the CLI""" + return importlib.metadata.version("crewai") + diff --git a/tests/cli/authentication/test_auth_main.py b/tests/cli/authentication/test_auth_main.py index c56968aab..4466cc999 100644 --- a/tests/cli/authentication/test_auth_main.py +++ b/tests/cli/authentication/test_auth_main.py @@ -43,10 +43,11 @@ class TestAuthenticationCommand(unittest.TestCase): mock_print.assert_any_call("2. Enter the following code: ", "ABCDEF") mock_open.assert_called_once_with("https://example.com") + @patch("crewai.cli.authentication.main.ToolCommand") @patch("crewai.cli.authentication.main.requests.post") @patch("crewai.cli.authentication.main.validate_token") @patch("crewai.cli.authentication.main.console.print") - def test_poll_for_token_success(self, mock_print, mock_validate_token, mock_post): + def test_poll_for_token_success(self, mock_print, mock_validate_token, mock_post, mock_tool): mock_response = MagicMock() mock_response.status_code = 200 mock_response.json.return_value = { @@ -55,10 +56,13 @@ class TestAuthenticationCommand(unittest.TestCase): } mock_post.return_value = mock_response + mock_instance = mock_tool.return_value + mock_instance.login.return_value = None + self.auth_command._poll_for_token({"device_code": "123456"}) mock_validate_token.assert_called_once_with("TOKEN") - mock_print.assert_called_once_with("\nWelcome to CrewAI+ !!", style="green") + mock_print.assert_called_once_with("\n[bold green]Welcome to CrewAI Enterprise![/bold green]\n") @patch("crewai.cli.authentication.main.requests.post") @patch("crewai.cli.authentication.main.console.print") diff --git a/tests/cli/deploy/test_deploy_main.py b/tests/cli/deploy/test_deploy_main.py index 385dbb8a5..bf1198a0b 100644 --- a/tests/cli/deploy/test_deploy_main.py +++ b/tests/cli/deploy/test_deploy_main.py @@ -260,6 +260,6 @@ class TestDeployCommand(unittest.TestCase): self.assertEqual(project_name, "test_project") def test_get_crewai_version(self): - from crewai.cli.utils import get_crewai_version + from crewai.cli.version import get_crewai_version assert isinstance(get_crewai_version(), str)