mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-10 08:38:30 +00:00
chore: remove auth0 and the need of typing the email on 'crewai login' (#3408)
* Remove the need of typing the email on 'crewai login' * Remove auth0 constants, update tests
This commit is contained in:
@@ -1,6 +1 @@
|
|||||||
ALGORITHMS = ["RS256"]
|
ALGORITHMS = ["RS256"]
|
||||||
|
|
||||||
#TODO: The AUTH0 constants should be removed after WorkOS migration is completed
|
|
||||||
AUTH0_DOMAIN = "crewai.us.auth0.com"
|
|
||||||
AUTH0_CLIENT_ID = "DEVC5Fw6NlRoSzmDCcOhVq85EfLBjKa8"
|
|
||||||
AUTH0_AUDIENCE = "https://crewai.us.auth0.com/api/v2/"
|
|
||||||
|
|||||||
@@ -9,14 +9,7 @@ from pydantic import BaseModel, Field
|
|||||||
|
|
||||||
from .utils import validate_jwt_token
|
from .utils import validate_jwt_token
|
||||||
from crewai.cli.shared.token_manager import TokenManager
|
from crewai.cli.shared.token_manager import TokenManager
|
||||||
from urllib.parse import quote
|
|
||||||
from crewai.cli.plus_api import PlusAPI
|
|
||||||
from crewai.cli.config import Settings
|
from crewai.cli.config import Settings
|
||||||
from crewai.cli.authentication.constants import (
|
|
||||||
AUTH0_AUDIENCE,
|
|
||||||
AUTH0_CLIENT_ID,
|
|
||||||
AUTH0_DOMAIN,
|
|
||||||
)
|
|
||||||
|
|
||||||
console = Console()
|
console = Console()
|
||||||
|
|
||||||
@@ -72,18 +65,6 @@ class AuthenticationCommand:
|
|||||||
"""Sign up to CrewAI+"""
|
"""Sign up to CrewAI+"""
|
||||||
console.print("Signing in to CrewAI Enterprise...\n", style="bold blue")
|
console.print("Signing in to CrewAI Enterprise...\n", style="bold blue")
|
||||||
|
|
||||||
# TODO: WORKOS - Next line and conditional are temporary until migration to WorkOS is complete.
|
|
||||||
user_provider = self._determine_user_provider()
|
|
||||||
if user_provider == "auth0":
|
|
||||||
settings = Oauth2Settings(
|
|
||||||
provider="auth0",
|
|
||||||
client_id=AUTH0_CLIENT_ID,
|
|
||||||
domain=AUTH0_DOMAIN,
|
|
||||||
audience=AUTH0_AUDIENCE,
|
|
||||||
)
|
|
||||||
self.oauth2_provider = ProviderFactory.from_settings(settings)
|
|
||||||
# End of temporary code.
|
|
||||||
|
|
||||||
device_code_data = self._get_device_code()
|
device_code_data = self._get_device_code()
|
||||||
self._display_auth_instructions(device_code_data)
|
self._display_auth_instructions(device_code_data)
|
||||||
|
|
||||||
@@ -206,30 +187,3 @@ class AuthenticationCommand:
|
|||||||
"\nRun [bold]crewai login[/bold] to try logging in again.\n",
|
"\nRun [bold]crewai login[/bold] to try logging in again.\n",
|
||||||
style="yellow",
|
style="yellow",
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: WORKOS - This method is temporary until migration to WorkOS is complete.
|
|
||||||
def _determine_user_provider(self) -> str:
|
|
||||||
"""Determine which provider to use for authentication."""
|
|
||||||
|
|
||||||
console.print(
|
|
||||||
"Enter your CrewAI Enterprise account email: ", style="bold blue", end=""
|
|
||||||
)
|
|
||||||
email = input()
|
|
||||||
email_encoded = quote(email)
|
|
||||||
|
|
||||||
# It's not correct to call this method directly, but it's temporary until migration is complete.
|
|
||||||
response = PlusAPI("")._make_request(
|
|
||||||
"GET", f"/crewai_plus/api/v1/me/provider?email={email_encoded}"
|
|
||||||
)
|
|
||||||
|
|
||||||
if response.status_code == 200:
|
|
||||||
if response.json().get("provider") == "auth0":
|
|
||||||
return "auth0"
|
|
||||||
else:
|
|
||||||
return "workos"
|
|
||||||
else:
|
|
||||||
console.print(
|
|
||||||
"Error: Failed to authenticate with crewai enterprise. Ensure that you are using the latest crewai version and please try again. If the problem persists, contact support@crewai.com.",
|
|
||||||
style="red",
|
|
||||||
)
|
|
||||||
raise SystemExit
|
|
||||||
|
|||||||
@@ -3,11 +3,6 @@ from datetime import datetime, timedelta
|
|||||||
import requests
|
import requests
|
||||||
from unittest.mock import MagicMock, patch, call
|
from unittest.mock import MagicMock, patch, call
|
||||||
from crewai.cli.authentication.main import AuthenticationCommand
|
from crewai.cli.authentication.main import AuthenticationCommand
|
||||||
from crewai.cli.authentication.constants import (
|
|
||||||
AUTH0_AUDIENCE,
|
|
||||||
AUTH0_CLIENT_ID,
|
|
||||||
AUTH0_DOMAIN
|
|
||||||
)
|
|
||||||
from crewai.cli.constants import (
|
from crewai.cli.constants import (
|
||||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||||
@@ -22,16 +17,6 @@ class TestAuthenticationCommand:
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"user_provider,expected_urls",
|
"user_provider,expected_urls",
|
||||||
[
|
[
|
||||||
(
|
|
||||||
"auth0",
|
|
||||||
{
|
|
||||||
"device_code_url": f"https://{AUTH0_DOMAIN}/oauth/device/code",
|
|
||||||
"token_url": f"https://{AUTH0_DOMAIN}/oauth/token",
|
|
||||||
"client_id": AUTH0_CLIENT_ID,
|
|
||||||
"audience": AUTH0_AUDIENCE,
|
|
||||||
"domain": AUTH0_DOMAIN,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
(
|
(
|
||||||
"workos",
|
"workos",
|
||||||
{
|
{
|
||||||
@@ -44,9 +29,6 @@ class TestAuthenticationCommand:
|
|||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@patch(
|
|
||||||
"crewai.cli.authentication.main.AuthenticationCommand._determine_user_provider"
|
|
||||||
)
|
|
||||||
@patch("crewai.cli.authentication.main.AuthenticationCommand._get_device_code")
|
@patch("crewai.cli.authentication.main.AuthenticationCommand._get_device_code")
|
||||||
@patch(
|
@patch(
|
||||||
"crewai.cli.authentication.main.AuthenticationCommand._display_auth_instructions"
|
"crewai.cli.authentication.main.AuthenticationCommand._display_auth_instructions"
|
||||||
@@ -59,11 +41,9 @@ class TestAuthenticationCommand:
|
|||||||
mock_poll,
|
mock_poll,
|
||||||
mock_display,
|
mock_display,
|
||||||
mock_get_device,
|
mock_get_device,
|
||||||
mock_determine_provider,
|
|
||||||
user_provider,
|
user_provider,
|
||||||
expected_urls,
|
expected_urls,
|
||||||
):
|
):
|
||||||
mock_determine_provider.return_value = user_provider
|
|
||||||
mock_get_device.return_value = {
|
mock_get_device.return_value = {
|
||||||
"device_code": "test_code",
|
"device_code": "test_code",
|
||||||
"user_code": "123456",
|
"user_code": "123456",
|
||||||
@@ -74,7 +54,6 @@ class TestAuthenticationCommand:
|
|||||||
mock_console_print.assert_called_once_with(
|
mock_console_print.assert_called_once_with(
|
||||||
"Signing in to CrewAI Enterprise...\n", style="bold blue"
|
"Signing in to CrewAI Enterprise...\n", style="bold blue"
|
||||||
)
|
)
|
||||||
mock_determine_provider.assert_called_once()
|
|
||||||
mock_get_device.assert_called_once()
|
mock_get_device.assert_called_once()
|
||||||
mock_display.assert_called_once_with(
|
mock_display.assert_called_once_with(
|
||||||
{"device_code": "test_code", "user_code": "123456"}
|
{"device_code": "test_code", "user_code": "123456"}
|
||||||
@@ -82,9 +61,17 @@ class TestAuthenticationCommand:
|
|||||||
mock_poll.assert_called_once_with(
|
mock_poll.assert_called_once_with(
|
||||||
{"device_code": "test_code", "user_code": "123456"},
|
{"device_code": "test_code", "user_code": "123456"},
|
||||||
)
|
)
|
||||||
assert self.auth_command.oauth2_provider.get_client_id() == expected_urls["client_id"]
|
assert (
|
||||||
assert self.auth_command.oauth2_provider.get_audience() == expected_urls["audience"]
|
self.auth_command.oauth2_provider.get_client_id()
|
||||||
assert self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
|
== expected_urls["client_id"]
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
self.auth_command.oauth2_provider.get_audience()
|
||||||
|
== expected_urls["audience"]
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
|
||||||
|
)
|
||||||
|
|
||||||
@patch("crewai.cli.authentication.main.webbrowser")
|
@patch("crewai.cli.authentication.main.webbrowser")
|
||||||
@patch("crewai.cli.authentication.main.console.print")
|
@patch("crewai.cli.authentication.main.console.print")
|
||||||
@@ -106,14 +93,6 @@ class TestAuthenticationCommand:
|
|||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"user_provider,jwt_config",
|
"user_provider,jwt_config",
|
||||||
[
|
[
|
||||||
(
|
|
||||||
"auth0",
|
|
||||||
{
|
|
||||||
"jwks_url": f"https://{AUTH0_DOMAIN}/.well-known/jwks.json",
|
|
||||||
"issuer": f"https://{AUTH0_DOMAIN}/",
|
|
||||||
"audience": AUTH0_AUDIENCE,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
(
|
(
|
||||||
"workos",
|
"workos",
|
||||||
{
|
{
|
||||||
@@ -135,14 +114,18 @@ class TestAuthenticationCommand:
|
|||||||
jwt_config,
|
jwt_config,
|
||||||
has_expiration,
|
has_expiration,
|
||||||
):
|
):
|
||||||
from crewai.cli.authentication.providers.auth0 import Auth0Provider
|
|
||||||
from crewai.cli.authentication.providers.workos import WorkosProvider
|
from crewai.cli.authentication.providers.workos import WorkosProvider
|
||||||
from crewai.cli.authentication.main import Oauth2Settings
|
from crewai.cli.authentication.main import Oauth2Settings
|
||||||
|
|
||||||
if user_provider == "auth0":
|
if user_provider == "workos":
|
||||||
self.auth_command.oauth2_provider = Auth0Provider(settings=Oauth2Settings(provider=user_provider, client_id="test-client-id", domain=AUTH0_DOMAIN, audience=jwt_config["audience"]))
|
self.auth_command.oauth2_provider = WorkosProvider(
|
||||||
elif user_provider == "workos":
|
settings=Oauth2Settings(
|
||||||
self.auth_command.oauth2_provider = WorkosProvider(settings=Oauth2Settings(provider=user_provider, client_id="test-client-id", domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, audience=jwt_config["audience"]))
|
provider=user_provider,
|
||||||
|
client_id="test-client-id",
|
||||||
|
domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||||
|
audience=jwt_config["audience"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
token_data = {"access_token": "test_access_token", "id_token": "test_id_token"}
|
token_data = {"access_token": "test_access_token", "id_token": "test_id_token"}
|
||||||
|
|
||||||
@@ -234,83 +217,6 @@ class TestAuthenticationCommand:
|
|||||||
]
|
]
|
||||||
mock_console_print.assert_has_calls(expected_calls)
|
mock_console_print.assert_has_calls(expected_calls)
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"api_response,expected_provider",
|
|
||||||
[
|
|
||||||
({"provider": "auth0"}, "auth0"),
|
|
||||||
({"provider": "workos"}, "workos"),
|
|
||||||
({"provider": "none"}, "workos"), # Default to workos for any other value
|
|
||||||
(
|
|
||||||
{},
|
|
||||||
"workos",
|
|
||||||
), # Default to workos if no provider key is sent in the response
|
|
||||||
],
|
|
||||||
)
|
|
||||||
@patch("crewai.cli.authentication.main.PlusAPI")
|
|
||||||
@patch("crewai.cli.authentication.main.console.print")
|
|
||||||
@patch("builtins.input", return_value="test@example.com")
|
|
||||||
def test_determine_user_provider_success(
|
|
||||||
self,
|
|
||||||
mock_input,
|
|
||||||
mock_console_print,
|
|
||||||
mock_plus_api,
|
|
||||||
api_response,
|
|
||||||
expected_provider,
|
|
||||||
):
|
|
||||||
mock_api_instance = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.status_code = 200
|
|
||||||
mock_response.json.return_value = api_response
|
|
||||||
mock_api_instance._make_request.return_value = mock_response
|
|
||||||
mock_plus_api.return_value = mock_api_instance
|
|
||||||
|
|
||||||
result = self.auth_command._determine_user_provider()
|
|
||||||
|
|
||||||
mock_input.assert_called_once()
|
|
||||||
|
|
||||||
mock_plus_api.assert_called_once_with("")
|
|
||||||
mock_api_instance._make_request.assert_called_once_with(
|
|
||||||
"GET", "/crewai_plus/api/v1/me/provider?email=test%40example.com"
|
|
||||||
)
|
|
||||||
|
|
||||||
assert result == expected_provider
|
|
||||||
|
|
||||||
@patch("crewai.cli.authentication.main.PlusAPI")
|
|
||||||
@patch("crewai.cli.authentication.main.console.print")
|
|
||||||
@patch("builtins.input", return_value="test@example.com")
|
|
||||||
def test_determine_user_provider_error(
|
|
||||||
self, mock_input, mock_console_print, mock_plus_api
|
|
||||||
):
|
|
||||||
mock_api_instance = MagicMock()
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.status_code = 500
|
|
||||||
mock_api_instance._make_request.return_value = mock_response
|
|
||||||
mock_plus_api.return_value = mock_api_instance
|
|
||||||
|
|
||||||
with pytest.raises(SystemExit):
|
|
||||||
self.auth_command._determine_user_provider()
|
|
||||||
|
|
||||||
mock_input.assert_called_once()
|
|
||||||
|
|
||||||
mock_plus_api.assert_called_once_with("")
|
|
||||||
mock_api_instance._make_request.assert_called_once_with(
|
|
||||||
"GET", "/crewai_plus/api/v1/me/provider?email=test%40example.com"
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_console_print.assert_has_calls(
|
|
||||||
[
|
|
||||||
call(
|
|
||||||
"Enter your CrewAI Enterprise account email: ",
|
|
||||||
style="bold blue",
|
|
||||||
end="",
|
|
||||||
),
|
|
||||||
call(
|
|
||||||
"Error: Failed to authenticate with crewai enterprise. Ensure that you are using the latest crewai version and please try again. If the problem persists, contact support@crewai.com.",
|
|
||||||
style="red",
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
@patch("requests.post")
|
@patch("requests.post")
|
||||||
def test_get_device_code(self, mock_post):
|
def test_get_device_code(self, mock_post):
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
@@ -323,7 +229,9 @@ class TestAuthenticationCommand:
|
|||||||
|
|
||||||
self.auth_command.oauth2_provider = MagicMock()
|
self.auth_command.oauth2_provider = MagicMock()
|
||||||
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
|
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
|
||||||
self.auth_command.oauth2_provider.get_authorize_url.return_value = "https://example.com/device"
|
self.auth_command.oauth2_provider.get_authorize_url.return_value = (
|
||||||
|
"https://example.com/device"
|
||||||
|
)
|
||||||
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
|
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
|
||||||
|
|
||||||
result = self.auth_command._get_device_code()
|
result = self.auth_command._get_device_code()
|
||||||
@@ -366,12 +274,12 @@ class TestAuthenticationCommand:
|
|||||||
) as mock_tool_login,
|
) as mock_tool_login,
|
||||||
):
|
):
|
||||||
self.auth_command.oauth2_provider = MagicMock()
|
self.auth_command.oauth2_provider = MagicMock()
|
||||||
self.auth_command.oauth2_provider.get_token_url.return_value = "https://example.com/token"
|
self.auth_command.oauth2_provider.get_token_url.return_value = (
|
||||||
|
"https://example.com/token"
|
||||||
|
)
|
||||||
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
|
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
|
||||||
|
|
||||||
self.auth_command._poll_for_token(
|
self.auth_command._poll_for_token(device_code_data)
|
||||||
device_code_data
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_post.assert_called_once_with(
|
mock_post.assert_called_once_with(
|
||||||
"https://example.com/token",
|
"https://example.com/token",
|
||||||
@@ -406,9 +314,7 @@ class TestAuthenticationCommand:
|
|||||||
"interval": 0.1, # Short interval for testing
|
"interval": 0.1, # Short interval for testing
|
||||||
}
|
}
|
||||||
|
|
||||||
self.auth_command._poll_for_token(
|
self.auth_command._poll_for_token(device_code_data)
|
||||||
device_code_data
|
|
||||||
)
|
|
||||||
|
|
||||||
mock_console_print.assert_any_call(
|
mock_console_print.assert_any_call(
|
||||||
"Timeout: Failed to get the token. Please try again.", style="bold red"
|
"Timeout: Failed to get the token. Please try again.", style="bold red"
|
||||||
@@ -429,15 +335,4 @@ class TestAuthenticationCommand:
|
|||||||
device_code_data = {"device_code": "test_device_code", "interval": 1}
|
device_code_data = {"device_code": "test_device_code", "interval": 1}
|
||||||
|
|
||||||
with pytest.raises(requests.HTTPError):
|
with pytest.raises(requests.HTTPError):
|
||||||
self.auth_command._poll_for_token(
|
self.auth_command._poll_for_token(device_code_data)
|
||||||
device_code_data
|
|
||||||
)
|
|
||||||
# @patch(
|
|
||||||
# "crewai.cli.authentication.main.AuthenticationCommand._determine_user_provider"
|
|
||||||
# )
|
|
||||||
# def test_login_with_auth0(self, mock_determine_provider):
|
|
||||||
# from crewai.cli.authentication.providers.auth0 import Auth0Provider
|
|
||||||
# from crewai.cli.authentication.main import Oauth2Settings
|
|
||||||
|
|
||||||
# self.auth_command.oauth2_provider = Auth0Provider(settings=Oauth2Settings(provider="auth0", client_id=AUTH0_CLIENT_ID, domain=AUTH0_DOMAIN, audience=AUTH0_AUDIENCE))
|
|
||||||
# self.auth_command.login()
|
|
||||||
|
|||||||
Reference in New Issue
Block a user