mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 23:58:34 +00:00
feat: add enterprise configure command (#3289)
* feat: add enterprise configure command * refactor: renaming EnterpriseCommand to EnterpriseConfigureCommand
This commit is contained in:
@@ -14,6 +14,7 @@ from crewai.memory.storage.kickoff_task_outputs_storage import (
|
||||
|
||||
from .authentication.main import AuthenticationCommand
|
||||
from .deploy.main import DeployCommand
|
||||
from .enterprise.main import EnterpriseConfigureCommand
|
||||
from .evaluate_crew import evaluate_crew
|
||||
from .install_crew import install_crew
|
||||
from .kickoff_flow import kickoff_flow
|
||||
@@ -392,6 +393,20 @@ def current():
|
||||
org_command.current()
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def enterprise():
|
||||
"""Enterprise Configuration commands."""
|
||||
pass
|
||||
|
||||
|
||||
@enterprise.command("configure")
|
||||
@click.argument("enterprise_url")
|
||||
def enterprise_configure(enterprise_url: str):
|
||||
"""Configure CrewAI Enterprise OAuth2 settings from the provided Enterprise URL."""
|
||||
enterprise_command = EnterpriseConfigureCommand()
|
||||
enterprise_command.configure(enterprise_url)
|
||||
|
||||
|
||||
@crewai.group()
|
||||
def config():
|
||||
"""CLI Configuration commands."""
|
||||
|
||||
0
src/crewai/cli/enterprise/__init__.py
Normal file
0
src/crewai/cli/enterprise/__init__.py
Normal file
84
src/crewai/cli/enterprise/main.py
Normal file
84
src/crewai/cli/enterprise/main.py
Normal file
@@ -0,0 +1,84 @@
|
||||
import requests
|
||||
from typing import Dict, Any
|
||||
from rich.console import Console
|
||||
from requests.exceptions import RequestException, JSONDecodeError
|
||||
|
||||
from crewai.cli.command import BaseCommand
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
from crewai.cli.version import get_crewai_version
|
||||
|
||||
console = Console()
|
||||
|
||||
|
||||
class EnterpriseConfigureCommand(BaseCommand):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.settings_command = SettingsCommand()
|
||||
|
||||
def configure(self, enterprise_url: str) -> None:
|
||||
try:
|
||||
enterprise_url = enterprise_url.rstrip('/')
|
||||
|
||||
oauth_config = self._fetch_oauth_config(enterprise_url)
|
||||
|
||||
self._update_oauth_settings(enterprise_url, oauth_config)
|
||||
|
||||
console.print(
|
||||
f"✅ Successfully configured CrewAI Enterprise with OAuth2 settings from {enterprise_url}",
|
||||
style="bold green"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
console.print(f"❌ Failed to configure Enterprise settings: {str(e)}", style="bold red")
|
||||
raise SystemExit(1)
|
||||
|
||||
def _fetch_oauth_config(self, enterprise_url: str) -> Dict[str, Any]:
|
||||
oauth_endpoint = f"{enterprise_url}/oauth/parameters"
|
||||
|
||||
try:
|
||||
console.print(f"🔄 Fetching OAuth2 configuration from {oauth_endpoint}...")
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": f"CrewAI-CLI/{get_crewai_version()}",
|
||||
"X-Crewai-Version": get_crewai_version(),
|
||||
}
|
||||
response = requests.get(oauth_endpoint, timeout=30, headers=headers)
|
||||
response.raise_for_status()
|
||||
|
||||
try:
|
||||
oauth_config = response.json()
|
||||
except JSONDecodeError:
|
||||
raise ValueError(f"Invalid JSON response from {oauth_endpoint}")
|
||||
|
||||
required_fields = ['audience', 'domain', 'device_authorization_client_id', 'provider']
|
||||
missing_fields = [field for field in required_fields if field not in oauth_config]
|
||||
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing required fields in OAuth2 configuration: {', '.join(missing_fields)}")
|
||||
|
||||
console.print("✅ Successfully retrieved OAuth2 configuration", style="green")
|
||||
return oauth_config
|
||||
|
||||
except RequestException as e:
|
||||
raise ValueError(f"Failed to connect to enterprise URL: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Error fetching OAuth2 configuration: {str(e)}")
|
||||
|
||||
def _update_oauth_settings(self, enterprise_url: str, oauth_config: Dict[str, Any]) -> None:
|
||||
try:
|
||||
config_mapping = {
|
||||
'enterprise_base_url': enterprise_url,
|
||||
'oauth2_provider': oauth_config['provider'],
|
||||
'oauth2_audience': oauth_config['audience'],
|
||||
'oauth2_client_id': oauth_config['device_authorization_client_id'],
|
||||
'oauth2_domain': oauth_config['domain']
|
||||
}
|
||||
|
||||
console.print("🔄 Updating local OAuth2 configuration...")
|
||||
|
||||
for key, value in config_mapping.items():
|
||||
self.settings_command.set(key, value)
|
||||
console.print(f" ✓ Set {key}: {value}", style="dim")
|
||||
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to update OAuth2 settings: {str(e)}")
|
||||
0
tests/cli/enterprise/__init__.py
Normal file
0
tests/cli/enterprise/__init__.py
Normal file
151
tests/cli/enterprise/test_main.py
Normal file
151
tests/cli/enterprise/test_main.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import requests
|
||||
from requests.exceptions import JSONDecodeError
|
||||
|
||||
from crewai.cli.enterprise.main import EnterpriseConfigureCommand
|
||||
from crewai.cli.settings.main import SettingsCommand
|
||||
import shutil
|
||||
|
||||
|
||||
class TestEnterpriseConfigureCommand(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.test_dir = Path(tempfile.mkdtemp())
|
||||
self.config_path = self.test_dir / "settings.json"
|
||||
|
||||
with patch('crewai.cli.enterprise.main.SettingsCommand') as mock_settings_command_class:
|
||||
self.mock_settings_command = Mock(spec=SettingsCommand)
|
||||
mock_settings_command_class.return_value = self.mock_settings_command
|
||||
|
||||
self.enterprise_command = EnterpriseConfigureCommand()
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.test_dir)
|
||||
|
||||
@patch('crewai.cli.enterprise.main.requests.get')
|
||||
@patch('crewai.cli.enterprise.main.get_crewai_version')
|
||||
def test_successful_configuration(self, mock_get_version, mock_requests_get):
|
||||
mock_get_version.return_value = "1.0.0"
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.json.return_value = {
|
||||
'audience': 'test_audience',
|
||||
'domain': 'test.domain.com',
|
||||
'device_authorization_client_id': 'test_client_id',
|
||||
'provider': 'workos'
|
||||
}
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
enterprise_url = "https://enterprise.example.com"
|
||||
self.enterprise_command.configure(enterprise_url)
|
||||
|
||||
expected_headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "CrewAI-CLI/1.0.0",
|
||||
"X-Crewai-Version": "1.0.0",
|
||||
}
|
||||
mock_requests_get.assert_called_once_with(
|
||||
"https://enterprise.example.com/oauth/parameters",
|
||||
timeout=30,
|
||||
headers=expected_headers
|
||||
)
|
||||
|
||||
expected_calls = [
|
||||
('enterprise_base_url', 'https://enterprise.example.com'),
|
||||
('oauth2_provider', 'workos'),
|
||||
('oauth2_audience', 'test_audience'),
|
||||
('oauth2_client_id', 'test_client_id'),
|
||||
('oauth2_domain', 'test.domain.com')
|
||||
]
|
||||
|
||||
actual_calls = self.mock_settings_command.set.call_args_list
|
||||
self.assertEqual(len(actual_calls), 5)
|
||||
|
||||
for i, (key, value) in enumerate(expected_calls):
|
||||
call_args = actual_calls[i][0]
|
||||
self.assertEqual(call_args[0], key)
|
||||
self.assertEqual(call_args[1], value)
|
||||
|
||||
@patch('crewai.cli.enterprise.main.requests.get')
|
||||
@patch('crewai.cli.enterprise.main.get_crewai_version')
|
||||
def test_http_error_handling(self, mock_get_version, mock_requests_get):
|
||||
mock_get_version.return_value = "1.0.0"
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.side_effect = requests.HTTPError("404 Not Found")
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
self.enterprise_command.configure("https://enterprise.example.com")
|
||||
|
||||
@patch('crewai.cli.enterprise.main.requests.get')
|
||||
@patch('crewai.cli.enterprise.main.get_crewai_version')
|
||||
def test_invalid_json_response(self, mock_get_version, mock_requests_get):
|
||||
mock_get_version.return_value = "1.0.0"
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.json.side_effect = JSONDecodeError("Invalid JSON", "", 0)
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
self.enterprise_command.configure("https://enterprise.example.com")
|
||||
|
||||
@patch('crewai.cli.enterprise.main.requests.get')
|
||||
@patch('crewai.cli.enterprise.main.get_crewai_version')
|
||||
def test_missing_required_fields(self, mock_get_version, mock_requests_get):
|
||||
mock_get_version.return_value = "1.0.0"
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.json.return_value = {
|
||||
'audience': 'test_audience',
|
||||
}
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
self.enterprise_command.configure("https://enterprise.example.com")
|
||||
|
||||
@patch('crewai.cli.enterprise.main.requests.get')
|
||||
@patch('crewai.cli.enterprise.main.get_crewai_version')
|
||||
def test_settings_update_error(self, mock_get_version, mock_requests_get):
|
||||
mock_get_version.return_value = "1.0.0"
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.json.return_value = {
|
||||
'audience': 'test_audience',
|
||||
'domain': 'test.domain.com',
|
||||
'device_authorization_client_id': 'test_client_id',
|
||||
'provider': 'workos'
|
||||
}
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
self.mock_settings_command.set.side_effect = Exception("Settings update failed")
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
self.enterprise_command.configure("https://enterprise.example.com")
|
||||
|
||||
def test_url_trailing_slash_removal(self):
|
||||
with patch.object(self.enterprise_command, '_fetch_oauth_config') as mock_fetch, \
|
||||
patch.object(self.enterprise_command, '_update_oauth_settings') as mock_update:
|
||||
|
||||
mock_fetch.return_value = {
|
||||
'audience': 'test_audience',
|
||||
'domain': 'test.domain.com',
|
||||
'device_authorization_client_id': 'test_client_id',
|
||||
'provider': 'workos'
|
||||
}
|
||||
|
||||
self.enterprise_command.configure("https://enterprise.example.com/")
|
||||
|
||||
mock_fetch.assert_called_once_with("https://enterprise.example.com")
|
||||
mock_update.assert_called_once()
|
||||
Reference in New Issue
Block a user