feat: add enterprise configure command (#3289)

* feat: add enterprise configure command

* refactor: renaming EnterpriseCommand to EnterpriseConfigureCommand
This commit is contained in:
Lucas Gomide
2025-08-08 09:50:01 -03:00
committed by GitHub
parent 915857541e
commit f9481cf10d
5 changed files with 250 additions and 0 deletions

View File

@@ -14,6 +14,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
from .authentication.main import AuthenticationCommand
from .deploy.main import DeployCommand
from .enterprise.main import EnterpriseConfigureCommand
from .evaluate_crew import evaluate_crew
from .install_crew import install_crew
from .kickoff_flow import kickoff_flow
@@ -392,6 +393,20 @@ def current():
org_command.current()
@crewai.group()
def enterprise():
"""Enterprise Configuration commands."""
pass
@enterprise.command("configure")
@click.argument("enterprise_url")
def enterprise_configure(enterprise_url: str):
"""Configure CrewAI Enterprise OAuth2 settings from the provided Enterprise URL."""
enterprise_command = EnterpriseConfigureCommand()
enterprise_command.configure(enterprise_url)
@crewai.group()
def config():
"""CLI Configuration commands."""

View File

View File

@@ -0,0 +1,84 @@
import requests
from typing import Dict, Any
from rich.console import Console
from requests.exceptions import RequestException, JSONDecodeError
from crewai.cli.command import BaseCommand
from crewai.cli.settings.main import SettingsCommand
from crewai.cli.version import get_crewai_version
console = Console()
class EnterpriseConfigureCommand(BaseCommand):
def __init__(self):
super().__init__()
self.settings_command = SettingsCommand()
def configure(self, enterprise_url: str) -> None:
try:
enterprise_url = enterprise_url.rstrip('/')
oauth_config = self._fetch_oauth_config(enterprise_url)
self._update_oauth_settings(enterprise_url, oauth_config)
console.print(
f"✅ Successfully configured CrewAI Enterprise with OAuth2 settings from {enterprise_url}",
style="bold green"
)
except Exception as e:
console.print(f"❌ Failed to configure Enterprise settings: {str(e)}", style="bold red")
raise SystemExit(1)
def _fetch_oauth_config(self, enterprise_url: str) -> Dict[str, Any]:
oauth_endpoint = f"{enterprise_url}/oauth/parameters"
try:
console.print(f"🔄 Fetching OAuth2 configuration from {oauth_endpoint}...")
headers = {
"Content-Type": "application/json",
"User-Agent": f"CrewAI-CLI/{get_crewai_version()}",
"X-Crewai-Version": get_crewai_version(),
}
response = requests.get(oauth_endpoint, timeout=30, headers=headers)
response.raise_for_status()
try:
oauth_config = response.json()
except JSONDecodeError:
raise ValueError(f"Invalid JSON response from {oauth_endpoint}")
required_fields = ['audience', 'domain', 'device_authorization_client_id', 'provider']
missing_fields = [field for field in required_fields if field not in oauth_config]
if missing_fields:
raise ValueError(f"Missing required fields in OAuth2 configuration: {', '.join(missing_fields)}")
console.print("✅ Successfully retrieved OAuth2 configuration", style="green")
return oauth_config
except RequestException as e:
raise ValueError(f"Failed to connect to enterprise URL: {str(e)}")
except Exception as e:
raise ValueError(f"Error fetching OAuth2 configuration: {str(e)}")
def _update_oauth_settings(self, enterprise_url: str, oauth_config: Dict[str, Any]) -> None:
try:
config_mapping = {
'enterprise_base_url': enterprise_url,
'oauth2_provider': oauth_config['provider'],
'oauth2_audience': oauth_config['audience'],
'oauth2_client_id': oauth_config['device_authorization_client_id'],
'oauth2_domain': oauth_config['domain']
}
console.print("🔄 Updating local OAuth2 configuration...")
for key, value in config_mapping.items():
self.settings_command.set(key, value)
console.print(f" ✓ Set {key}: {value}", style="dim")
except Exception as e:
raise ValueError(f"Failed to update OAuth2 settings: {str(e)}")

View File

View File

@@ -0,0 +1,151 @@
import tempfile
import unittest
from pathlib import Path
from unittest.mock import Mock, patch
import requests
from requests.exceptions import JSONDecodeError
from crewai.cli.enterprise.main import EnterpriseConfigureCommand
from crewai.cli.settings.main import SettingsCommand
import shutil
class TestEnterpriseConfigureCommand(unittest.TestCase):
def setUp(self):
self.test_dir = Path(tempfile.mkdtemp())
self.config_path = self.test_dir / "settings.json"
with patch('crewai.cli.enterprise.main.SettingsCommand') as mock_settings_command_class:
self.mock_settings_command = Mock(spec=SettingsCommand)
mock_settings_command_class.return_value = self.mock_settings_command
self.enterprise_command = EnterpriseConfigureCommand()
def tearDown(self):
shutil.rmtree(self.test_dir)
@patch('crewai.cli.enterprise.main.requests.get')
@patch('crewai.cli.enterprise.main.get_crewai_version')
def test_successful_configuration(self, mock_get_version, mock_requests_get):
mock_get_version.return_value = "1.0.0"
mock_response = Mock()
mock_response.status_code = 200
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {
'audience': 'test_audience',
'domain': 'test.domain.com',
'device_authorization_client_id': 'test_client_id',
'provider': 'workos'
}
mock_requests_get.return_value = mock_response
enterprise_url = "https://enterprise.example.com"
self.enterprise_command.configure(enterprise_url)
expected_headers = {
"Content-Type": "application/json",
"User-Agent": "CrewAI-CLI/1.0.0",
"X-Crewai-Version": "1.0.0",
}
mock_requests_get.assert_called_once_with(
"https://enterprise.example.com/oauth/parameters",
timeout=30,
headers=expected_headers
)
expected_calls = [
('enterprise_base_url', 'https://enterprise.example.com'),
('oauth2_provider', 'workos'),
('oauth2_audience', 'test_audience'),
('oauth2_client_id', 'test_client_id'),
('oauth2_domain', 'test.domain.com')
]
actual_calls = self.mock_settings_command.set.call_args_list
self.assertEqual(len(actual_calls), 5)
for i, (key, value) in enumerate(expected_calls):
call_args = actual_calls[i][0]
self.assertEqual(call_args[0], key)
self.assertEqual(call_args[1], value)
@patch('crewai.cli.enterprise.main.requests.get')
@patch('crewai.cli.enterprise.main.get_crewai_version')
def test_http_error_handling(self, mock_get_version, mock_requests_get):
mock_get_version.return_value = "1.0.0"
mock_response = Mock()
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
mock_requests_get.return_value = mock_response
with self.assertRaises(SystemExit):
self.enterprise_command.configure("https://enterprise.example.com")
@patch('crewai.cli.enterprise.main.requests.get')
@patch('crewai.cli.enterprise.main.get_crewai_version')
def test_invalid_json_response(self, mock_get_version, mock_requests_get):
mock_get_version.return_value = "1.0.0"
mock_response = Mock()
mock_response.status_code = 200
mock_response.raise_for_status.return_value = None
mock_response.json.side_effect = JSONDecodeError("Invalid JSON", "", 0)
mock_requests_get.return_value = mock_response
with self.assertRaises(SystemExit):
self.enterprise_command.configure("https://enterprise.example.com")
@patch('crewai.cli.enterprise.main.requests.get')
@patch('crewai.cli.enterprise.main.get_crewai_version')
def test_missing_required_fields(self, mock_get_version, mock_requests_get):
mock_get_version.return_value = "1.0.0"
mock_response = Mock()
mock_response.status_code = 200
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {
'audience': 'test_audience',
}
mock_requests_get.return_value = mock_response
with self.assertRaises(SystemExit):
self.enterprise_command.configure("https://enterprise.example.com")
@patch('crewai.cli.enterprise.main.requests.get')
@patch('crewai.cli.enterprise.main.get_crewai_version')
def test_settings_update_error(self, mock_get_version, mock_requests_get):
mock_get_version.return_value = "1.0.0"
mock_response = Mock()
mock_response.status_code = 200
mock_response.raise_for_status.return_value = None
mock_response.json.return_value = {
'audience': 'test_audience',
'domain': 'test.domain.com',
'device_authorization_client_id': 'test_client_id',
'provider': 'workos'
}
mock_requests_get.return_value = mock_response
self.mock_settings_command.set.side_effect = Exception("Settings update failed")
with self.assertRaises(SystemExit):
self.enterprise_command.configure("https://enterprise.example.com")
def test_url_trailing_slash_removal(self):
with patch.object(self.enterprise_command, '_fetch_oauth_config') as mock_fetch, \
patch.object(self.enterprise_command, '_update_oauth_settings') as mock_update:
mock_fetch.return_value = {
'audience': 'test_audience',
'domain': 'test.domain.com',
'device_authorization_client_id': 'test_client_id',
'provider': 'workos'
}
self.enterprise_command.configure("https://enterprise.example.com/")
mock_fetch.assert_called_once_with("https://enterprise.example.com")
mock_update.assert_called_once()