mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-24 15:48:23 +00:00
feat: reset tokens on crewai config reset (#3365)
This commit is contained in:
@@ -7,7 +7,8 @@ from rich.console import Console
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from .utils import TokenManager, validate_jwt_token
|
||||
from .utils import validate_jwt_token
|
||||
from crewai.cli.shared.token_manager import TokenManager
|
||||
from urllib.parse import quote
|
||||
from crewai.cli.plus_api import PlusAPI
|
||||
from crewai.cli.config import Settings
|
||||
@@ -21,10 +22,19 @@ console = Console()
|
||||
|
||||
|
||||
class Oauth2Settings(BaseModel):
|
||||
provider: str = Field(description="OAuth2 provider used for authentication (e.g., workos, okta, auth0).")
|
||||
client_id: str = Field(description="OAuth2 client ID issued by the provider, used during authentication requests.")
|
||||
domain: str = Field(description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens.")
|
||||
audience: Optional[str] = Field(description="OAuth2 audience value, typically used to identify the target API or resource.", default=None)
|
||||
provider: str = Field(
|
||||
description="OAuth2 provider used for authentication (e.g., workos, okta, auth0)."
|
||||
)
|
||||
client_id: str = Field(
|
||||
description="OAuth2 client ID issued by the provider, used during authentication requests."
|
||||
)
|
||||
domain: str = Field(
|
||||
description="OAuth2 provider's domain (e.g., your-org.auth0.com) used for issuing tokens."
|
||||
)
|
||||
audience: Optional[str] = Field(
|
||||
description="OAuth2 audience value, typically used to identify the target API or resource.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_settings(cls):
|
||||
@@ -44,11 +54,15 @@ class ProviderFactory:
|
||||
settings = settings or Oauth2Settings.from_settings()
|
||||
|
||||
import importlib
|
||||
module = importlib.import_module(f"crewai.cli.authentication.providers.{settings.provider.lower()}")
|
||||
|
||||
module = importlib.import_module(
|
||||
f"crewai.cli.authentication.providers.{settings.provider.lower()}"
|
||||
)
|
||||
provider = getattr(module, f"{settings.provider.capitalize()}Provider")
|
||||
|
||||
return provider(settings)
|
||||
|
||||
|
||||
class AuthenticationCommand:
|
||||
def __init__(self):
|
||||
self.token_manager = TokenManager()
|
||||
@@ -65,7 +79,7 @@ class AuthenticationCommand:
|
||||
provider="auth0",
|
||||
client_id=AUTH0_CLIENT_ID,
|
||||
domain=AUTH0_DOMAIN,
|
||||
audience=AUTH0_AUDIENCE
|
||||
audience=AUTH0_AUDIENCE,
|
||||
)
|
||||
self.oauth2_provider = ProviderFactory.from_settings(settings)
|
||||
# End of temporary code.
|
||||
@@ -75,9 +89,7 @@ class AuthenticationCommand:
|
||||
|
||||
return self._poll_for_token(device_code_data)
|
||||
|
||||
def _get_device_code(
|
||||
self
|
||||
) -> Dict[str, Any]:
|
||||
def _get_device_code(self) -> Dict[str, Any]:
|
||||
"""Get the device code to authenticate the user."""
|
||||
|
||||
device_code_payload = {
|
||||
@@ -86,7 +98,9 @@ class AuthenticationCommand:
|
||||
"audience": self.oauth2_provider.get_audience(),
|
||||
}
|
||||
response = requests.post(
|
||||
url=self.oauth2_provider.get_authorize_url(), data=device_code_payload, timeout=20
|
||||
url=self.oauth2_provider.get_authorize_url(),
|
||||
data=device_code_payload,
|
||||
timeout=20,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -97,9 +111,7 @@ class AuthenticationCommand:
|
||||
console.print("2. Enter the following code: ", device_code_data["user_code"])
|
||||
webbrowser.open(device_code_data["verification_uri_complete"])
|
||||
|
||||
def _poll_for_token(
|
||||
self, device_code_data: Dict[str, Any]
|
||||
) -> None:
|
||||
def _poll_for_token(self, device_code_data: Dict[str, Any]) -> None:
|
||||
"""Polls the server for the token until it is received, or max attempts are reached."""
|
||||
|
||||
token_payload = {
|
||||
@@ -112,7 +124,9 @@ class AuthenticationCommand:
|
||||
|
||||
attempts = 0
|
||||
while True and attempts < 10:
|
||||
response = requests.post(self.oauth2_provider.get_token_url(), data=token_payload, timeout=30)
|
||||
response = requests.post(
|
||||
self.oauth2_provider.get_token_url(), data=token_payload, timeout=30
|
||||
)
|
||||
token_data = response.json()
|
||||
|
||||
if response.status_code == 200:
|
||||
|
||||
Reference in New Issue
Block a user