mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
feat: add crewai config command group and tests (#3206)
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import Optional
|
|||||||
|
|
||||||
import click
|
import click
|
||||||
from crewai.cli.config import Settings
|
from crewai.cli.config import Settings
|
||||||
|
from crewai.cli.settings.main import SettingsCommand
|
||||||
from crewai.cli.add_crew_to_flow import add_crew_to_flow
|
from crewai.cli.add_crew_to_flow import add_crew_to_flow
|
||||||
from crewai.cli.create_crew import create_crew
|
from crewai.cli.create_crew import create_crew
|
||||||
from crewai.cli.create_flow import create_flow
|
from crewai.cli.create_flow import create_flow
|
||||||
@@ -227,7 +228,7 @@ def update():
|
|||||||
@crewai.command()
|
@crewai.command()
|
||||||
def login():
|
def login():
|
||||||
"""Sign Up/Login to CrewAI Enterprise."""
|
"""Sign Up/Login to CrewAI Enterprise."""
|
||||||
Settings().clear()
|
Settings().clear_user_settings()
|
||||||
AuthenticationCommand().login()
|
AuthenticationCommand().login()
|
||||||
|
|
||||||
|
|
||||||
@@ -369,8 +370,8 @@ def org():
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
@org.command()
|
@org.command("list")
|
||||||
def list():
|
def org_list():
|
||||||
"""List available organizations."""
|
"""List available organizations."""
|
||||||
org_command = OrganizationCommand()
|
org_command = OrganizationCommand()
|
||||||
org_command.list()
|
org_command.list()
|
||||||
@@ -391,5 +392,34 @@ def current():
|
|||||||
org_command.current()
|
org_command.current()
|
||||||
|
|
||||||
|
|
||||||
|
@crewai.group()
|
||||||
|
def config():
|
||||||
|
"""CLI Configuration commands."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@config.command("list")
|
||||||
|
def config_list():
|
||||||
|
"""List all CLI configuration parameters."""
|
||||||
|
config_command = SettingsCommand()
|
||||||
|
config_command.list()
|
||||||
|
|
||||||
|
|
||||||
|
@config.command("set")
|
||||||
|
@click.argument("key")
|
||||||
|
@click.argument("value")
|
||||||
|
def config_set(key: str, value: str):
|
||||||
|
"""Set a CLI configuration parameter."""
|
||||||
|
config_command = SettingsCommand()
|
||||||
|
config_command.set(key, value)
|
||||||
|
|
||||||
|
|
||||||
|
@config.command("reset")
|
||||||
|
def config_reset():
|
||||||
|
"""Reset all CLI configuration parameters to default values."""
|
||||||
|
config_command = SettingsCommand()
|
||||||
|
config_command.reset_all_settings()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
crewai()
|
crewai()
|
||||||
|
|||||||
@@ -4,10 +4,47 @@ from typing import Optional
|
|||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||||
|
|
||||||
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
|
DEFAULT_CONFIG_PATH = Path.home() / ".config" / "crewai" / "settings.json"
|
||||||
|
|
||||||
|
# Settings that are related to the user's account
|
||||||
|
USER_SETTINGS_KEYS = [
|
||||||
|
"tool_repository_username",
|
||||||
|
"tool_repository_password",
|
||||||
|
"org_name",
|
||||||
|
"org_uuid",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Settings that are related to the CLI
|
||||||
|
CLI_SETTINGS_KEYS = [
|
||||||
|
"enterprise_base_url",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Default values for CLI settings
|
||||||
|
DEFAULT_CLI_SETTINGS = {
|
||||||
|
"enterprise_base_url": DEFAULT_CREWAI_ENTERPRISE_URL,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Readonly settings - cannot be set by the user
|
||||||
|
READONLY_SETTINGS_KEYS = [
|
||||||
|
"org_name",
|
||||||
|
"org_uuid",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Hidden settings - not displayed by the 'list' command and cannot be set by the user
|
||||||
|
HIDDEN_SETTINGS_KEYS = [
|
||||||
|
"config_path",
|
||||||
|
"tool_repository_username",
|
||||||
|
"tool_repository_password",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseModel):
|
class Settings(BaseModel):
|
||||||
|
enterprise_base_url: Optional[str] = Field(
|
||||||
|
default=DEFAULT_CREWAI_ENTERPRISE_URL,
|
||||||
|
description="Base URL of the CrewAI Enterprise instance",
|
||||||
|
)
|
||||||
tool_repository_username: Optional[str] = Field(
|
tool_repository_username: Optional[str] = Field(
|
||||||
None, description="Username for interacting with the Tool Repository"
|
None, description="Username for interacting with the Tool Repository"
|
||||||
)
|
)
|
||||||
@@ -20,7 +57,7 @@ class Settings(BaseModel):
|
|||||||
org_uuid: Optional[str] = Field(
|
org_uuid: Optional[str] = Field(
|
||||||
None, description="UUID of the currently active organization"
|
None, description="UUID of the currently active organization"
|
||||||
)
|
)
|
||||||
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, exclude=True)
|
config_path: Path = Field(default=DEFAULT_CONFIG_PATH, frozen=True, exclude=True)
|
||||||
|
|
||||||
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
|
def __init__(self, config_path: Path = DEFAULT_CONFIG_PATH, **data):
|
||||||
"""Load Settings from config path"""
|
"""Load Settings from config path"""
|
||||||
@@ -37,9 +74,16 @@ class Settings(BaseModel):
|
|||||||
merged_data = {**file_data, **data}
|
merged_data = {**file_data, **data}
|
||||||
super().__init__(config_path=config_path, **merged_data)
|
super().__init__(config_path=config_path, **merged_data)
|
||||||
|
|
||||||
def clear(self) -> None:
|
def clear_user_settings(self) -> None:
|
||||||
"""Clear all settings"""
|
"""Clear all user settings"""
|
||||||
self.config_path.unlink(missing_ok=True)
|
self._reset_user_settings()
|
||||||
|
self.dump()
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Reset all settings to default values"""
|
||||||
|
self._reset_user_settings()
|
||||||
|
self._reset_cli_settings()
|
||||||
|
self.dump()
|
||||||
|
|
||||||
def dump(self) -> None:
|
def dump(self) -> None:
|
||||||
"""Save current settings to settings.json"""
|
"""Save current settings to settings.json"""
|
||||||
@@ -52,3 +96,13 @@ class Settings(BaseModel):
|
|||||||
updated_data = {**existing_data, **self.model_dump(exclude_unset=True)}
|
updated_data = {**existing_data, **self.model_dump(exclude_unset=True)}
|
||||||
with self.config_path.open("w") as f:
|
with self.config_path.open("w") as f:
|
||||||
json.dump(updated_data, f, indent=4)
|
json.dump(updated_data, f, indent=4)
|
||||||
|
|
||||||
|
def _reset_user_settings(self) -> None:
|
||||||
|
"""Reset all user settings to default values"""
|
||||||
|
for key in USER_SETTINGS_KEYS:
|
||||||
|
setattr(self, key, None)
|
||||||
|
|
||||||
|
def _reset_cli_settings(self) -> None:
|
||||||
|
"""Reset all CLI settings to default values"""
|
||||||
|
for key in CLI_SETTINGS_KEYS:
|
||||||
|
setattr(self, key, DEFAULT_CLI_SETTINGS[key])
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
DEFAULT_CREWAI_ENTERPRISE_URL = "https://app.crewai.com"
|
||||||
|
|
||||||
ENV_VARS = {
|
ENV_VARS = {
|
||||||
"openai": [
|
"openai": [
|
||||||
{
|
{
|
||||||
@@ -320,5 +322,4 @@ DEFAULT_LLM_MODEL = "gpt-4o-mini"
|
|||||||
|
|
||||||
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
JSON_URL = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
|
||||||
|
|
||||||
|
|
||||||
LITELLM_PARAMS = ["api_key", "api_base", "api_version"]
|
LITELLM_PARAMS = ["api_key", "api_base", "api_version"]
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
from os import getenv
|
|
||||||
from typing import List, Optional
|
from typing import List, Optional
|
||||||
from urllib.parse import urljoin
|
from urllib.parse import urljoin
|
||||||
|
|
||||||
@@ -6,6 +5,7 @@ import requests
|
|||||||
|
|
||||||
from crewai.cli.config import Settings
|
from crewai.cli.config import Settings
|
||||||
from crewai.cli.version import get_crewai_version
|
from crewai.cli.version import get_crewai_version
|
||||||
|
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||||
|
|
||||||
|
|
||||||
class PlusAPI:
|
class PlusAPI:
|
||||||
@@ -29,7 +29,10 @@ class PlusAPI:
|
|||||||
settings = Settings()
|
settings = Settings()
|
||||||
if settings.org_uuid:
|
if settings.org_uuid:
|
||||||
self.headers["X-Crewai-Organization-Id"] = settings.org_uuid
|
self.headers["X-Crewai-Organization-Id"] = settings.org_uuid
|
||||||
self.base_url = getenv("CREWAI_BASE_URL", "https://app.crewai.com")
|
|
||||||
|
self.base_url = (
|
||||||
|
str(settings.enterprise_base_url) or DEFAULT_CREWAI_ENTERPRISE_URL
|
||||||
|
)
|
||||||
|
|
||||||
def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
|
def _make_request(self, method: str, endpoint: str, **kwargs) -> requests.Response:
|
||||||
url = urljoin(self.base_url, endpoint)
|
url = urljoin(self.base_url, endpoint)
|
||||||
@@ -108,7 +111,6 @@ class PlusAPI:
|
|||||||
|
|
||||||
def create_crew(self, payload) -> requests.Response:
|
def create_crew(self, payload) -> requests.Response:
|
||||||
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
|
return self._make_request("POST", self.CREWS_RESOURCE, json=payload)
|
||||||
|
|
||||||
def get_organizations(self) -> requests.Response:
|
def get_organizations(self) -> requests.Response:
|
||||||
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
|
return self._make_request("GET", self.ORGANIZATIONS_RESOURCE)
|
||||||
|
|
||||||
0
src/crewai/cli/settings/__init__.py
Normal file
0
src/crewai/cli/settings/__init__.py
Normal file
67
src/crewai/cli/settings/main.py
Normal file
67
src/crewai/cli/settings/main.py
Normal file
@@ -0,0 +1,67 @@
|
|||||||
|
from rich.console import Console
|
||||||
|
from rich.table import Table
|
||||||
|
from crewai.cli.command import BaseCommand
|
||||||
|
from crewai.cli.config import Settings, READONLY_SETTINGS_KEYS, HIDDEN_SETTINGS_KEYS
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
console = Console()
|
||||||
|
|
||||||
|
|
||||||
|
class SettingsCommand(BaseCommand):
|
||||||
|
"""A class to handle CLI configuration commands."""
|
||||||
|
|
||||||
|
def __init__(self, settings_kwargs: dict[str, Any] = {}):
|
||||||
|
super().__init__()
|
||||||
|
self.settings = Settings(**settings_kwargs)
|
||||||
|
|
||||||
|
def list(self) -> None:
|
||||||
|
"""List all CLI configuration parameters."""
|
||||||
|
table = Table(title="CrewAI CLI Configuration")
|
||||||
|
table.add_column("Setting", style="cyan", no_wrap=True)
|
||||||
|
table.add_column("Value", style="green")
|
||||||
|
table.add_column("Description", style="yellow")
|
||||||
|
|
||||||
|
# Add all settings to the table
|
||||||
|
for field_name, field_info in Settings.model_fields.items():
|
||||||
|
if field_name in HIDDEN_SETTINGS_KEYS:
|
||||||
|
# Do not display hidden settings
|
||||||
|
continue
|
||||||
|
|
||||||
|
current_value = getattr(self.settings, field_name)
|
||||||
|
description = field_info.description or "No description available"
|
||||||
|
display_value = (
|
||||||
|
str(current_value) if current_value is not None else "Not set"
|
||||||
|
)
|
||||||
|
|
||||||
|
table.add_row(field_name, display_value, description)
|
||||||
|
|
||||||
|
console.print(table)
|
||||||
|
|
||||||
|
def set(self, key: str, value: str) -> None:
|
||||||
|
"""Set a CLI configuration parameter."""
|
||||||
|
|
||||||
|
readonly_settings = READONLY_SETTINGS_KEYS + HIDDEN_SETTINGS_KEYS
|
||||||
|
|
||||||
|
if not hasattr(self.settings, key) or key in readonly_settings:
|
||||||
|
console.print(
|
||||||
|
f"Error: Unknown or readonly configuration key '{key}'",
|
||||||
|
style="bold red",
|
||||||
|
)
|
||||||
|
console.print("Available keys:", style="yellow")
|
||||||
|
for field_name in Settings.model_fields.keys():
|
||||||
|
if field_name not in readonly_settings:
|
||||||
|
console.print(f" - {field_name}", style="yellow")
|
||||||
|
raise SystemExit(1)
|
||||||
|
|
||||||
|
setattr(self.settings, key, value)
|
||||||
|
self.settings.dump()
|
||||||
|
|
||||||
|
console.print(f"Successfully set '{key}' to '{value}'", style="bold green")
|
||||||
|
|
||||||
|
def reset_all_settings(self) -> None:
|
||||||
|
"""Reset all CLI configuration parameters to default values."""
|
||||||
|
self.settings.reset()
|
||||||
|
console.print(
|
||||||
|
"Successfully reset all configuration parameters to default values. It is recommended to run [bold yellow]'crewai login'[/bold yellow] to re-authenticate.",
|
||||||
|
style="bold green",
|
||||||
|
)
|
||||||
@@ -4,7 +4,12 @@ import tempfile
|
|||||||
import unittest
|
import unittest
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from crewai.cli.config import Settings
|
from crewai.cli.config import (
|
||||||
|
Settings,
|
||||||
|
USER_SETTINGS_KEYS,
|
||||||
|
CLI_SETTINGS_KEYS,
|
||||||
|
DEFAULT_CLI_SETTINGS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestSettings(unittest.TestCase):
|
class TestSettings(unittest.TestCase):
|
||||||
@@ -52,6 +57,30 @@ class TestSettings(unittest.TestCase):
|
|||||||
self.assertEqual(settings.tool_repository_username, "new_user")
|
self.assertEqual(settings.tool_repository_username, "new_user")
|
||||||
self.assertEqual(settings.tool_repository_password, "file_pass")
|
self.assertEqual(settings.tool_repository_password, "file_pass")
|
||||||
|
|
||||||
|
def test_clear_user_settings(self):
|
||||||
|
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
|
||||||
|
|
||||||
|
settings = Settings(config_path=self.config_path, **user_settings)
|
||||||
|
settings.clear_user_settings()
|
||||||
|
|
||||||
|
for key in user_settings.keys():
|
||||||
|
self.assertEqual(getattr(settings, key), None)
|
||||||
|
|
||||||
|
def test_reset_settings(self):
|
||||||
|
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
|
||||||
|
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS}
|
||||||
|
|
||||||
|
settings = Settings(
|
||||||
|
config_path=self.config_path, **user_settings, **cli_settings
|
||||||
|
)
|
||||||
|
|
||||||
|
settings.reset()
|
||||||
|
|
||||||
|
for key in user_settings.keys():
|
||||||
|
self.assertEqual(getattr(settings, key), None)
|
||||||
|
for key in cli_settings.keys():
|
||||||
|
self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS[key])
|
||||||
|
|
||||||
def test_dump_new_settings(self):
|
def test_dump_new_settings(self):
|
||||||
settings = Settings(
|
settings = Settings(
|
||||||
config_path=self.config_path, tool_repository_username="user1"
|
config_path=self.config_path, tool_repository_username="user1"
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ from click.testing import CliRunner
|
|||||||
import requests
|
import requests
|
||||||
|
|
||||||
from crewai.cli.organization.main import OrganizationCommand
|
from crewai.cli.organization.main import OrganizationCommand
|
||||||
from crewai.cli.cli import list, switch, current
|
from crewai.cli.cli import org_list, switch, current
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
@@ -16,44 +16,44 @@ def runner():
|
|||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def org_command():
|
def org_command():
|
||||||
with patch.object(OrganizationCommand, '__init__', return_value=None):
|
with patch.object(OrganizationCommand, "__init__", return_value=None):
|
||||||
command = OrganizationCommand()
|
command = OrganizationCommand()
|
||||||
yield command
|
yield command
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_settings():
|
def mock_settings():
|
||||||
with patch('crewai.cli.organization.main.Settings') as mock_settings_class:
|
with patch("crewai.cli.organization.main.Settings") as mock_settings_class:
|
||||||
mock_settings_instance = MagicMock()
|
mock_settings_instance = MagicMock()
|
||||||
mock_settings_class.return_value = mock_settings_instance
|
mock_settings_class.return_value = mock_settings_instance
|
||||||
yield mock_settings_instance
|
yield mock_settings_instance
|
||||||
|
|
||||||
|
|
||||||
@patch('crewai.cli.cli.OrganizationCommand')
|
@patch("crewai.cli.cli.OrganizationCommand")
|
||||||
def test_org_list_command(mock_org_command_class, runner):
|
def test_org_list_command(mock_org_command_class, runner):
|
||||||
mock_org_instance = MagicMock()
|
mock_org_instance = MagicMock()
|
||||||
mock_org_command_class.return_value = mock_org_instance
|
mock_org_command_class.return_value = mock_org_instance
|
||||||
|
|
||||||
result = runner.invoke(list)
|
result = runner.invoke(org_list)
|
||||||
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
mock_org_command_class.assert_called_once()
|
mock_org_command_class.assert_called_once()
|
||||||
mock_org_instance.list.assert_called_once()
|
mock_org_instance.list.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
@patch('crewai.cli.cli.OrganizationCommand')
|
@patch("crewai.cli.cli.OrganizationCommand")
|
||||||
def test_org_switch_command(mock_org_command_class, runner):
|
def test_org_switch_command(mock_org_command_class, runner):
|
||||||
mock_org_instance = MagicMock()
|
mock_org_instance = MagicMock()
|
||||||
mock_org_command_class.return_value = mock_org_instance
|
mock_org_command_class.return_value = mock_org_instance
|
||||||
|
|
||||||
result = runner.invoke(switch, ['test-id'])
|
result = runner.invoke(switch, ["test-id"])
|
||||||
|
|
||||||
assert result.exit_code == 0
|
assert result.exit_code == 0
|
||||||
mock_org_command_class.assert_called_once()
|
mock_org_command_class.assert_called_once()
|
||||||
mock_org_instance.switch.assert_called_once_with('test-id')
|
mock_org_instance.switch.assert_called_once_with("test-id")
|
||||||
|
|
||||||
|
|
||||||
@patch('crewai.cli.cli.OrganizationCommand')
|
@patch("crewai.cli.cli.OrganizationCommand")
|
||||||
def test_org_current_command(mock_org_command_class, runner):
|
def test_org_current_command(mock_org_command_class, runner):
|
||||||
mock_org_instance = MagicMock()
|
mock_org_instance = MagicMock()
|
||||||
mock_org_command_class.return_value = mock_org_instance
|
mock_org_command_class.return_value = mock_org_instance
|
||||||
@@ -67,18 +67,18 @@ def test_org_current_command(mock_org_command_class, runner):
|
|||||||
|
|
||||||
class TestOrganizationCommand(unittest.TestCase):
|
class TestOrganizationCommand(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
with patch.object(OrganizationCommand, '__init__', return_value=None):
|
with patch.object(OrganizationCommand, "__init__", return_value=None):
|
||||||
self.org_command = OrganizationCommand()
|
self.org_command = OrganizationCommand()
|
||||||
self.org_command.plus_api_client = MagicMock()
|
self.org_command.plus_api_client = MagicMock()
|
||||||
|
|
||||||
@patch('crewai.cli.organization.main.console')
|
@patch("crewai.cli.organization.main.console")
|
||||||
@patch('crewai.cli.organization.main.Table')
|
@patch("crewai.cli.organization.main.Table")
|
||||||
def test_list_organizations_success(self, mock_table, mock_console):
|
def test_list_organizations_success(self, mock_table, mock_console):
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.raise_for_status = MagicMock()
|
mock_response.raise_for_status = MagicMock()
|
||||||
mock_response.json.return_value = [
|
mock_response.json.return_value = [
|
||||||
{"name": "Org 1", "uuid": "org-123"},
|
{"name": "Org 1", "uuid": "org-123"},
|
||||||
{"name": "Org 2", "uuid": "org-456"}
|
{"name": "Org 2", "uuid": "org-456"},
|
||||||
]
|
]
|
||||||
self.org_command.plus_api_client = MagicMock()
|
self.org_command.plus_api_client = MagicMock()
|
||||||
self.org_command.plus_api_client.get_organizations.return_value = mock_response
|
self.org_command.plus_api_client.get_organizations.return_value = mock_response
|
||||||
@@ -89,16 +89,14 @@ class TestOrganizationCommand(unittest.TestCase):
|
|||||||
|
|
||||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||||
mock_table.assert_called_once_with(title="Your Organizations")
|
mock_table.assert_called_once_with(title="Your Organizations")
|
||||||
mock_table.return_value.add_column.assert_has_calls([
|
mock_table.return_value.add_column.assert_has_calls(
|
||||||
call("Name", style="cyan"),
|
[call("Name", style="cyan"), call("ID", style="green")]
|
||||||
call("ID", style="green")
|
)
|
||||||
])
|
mock_table.return_value.add_row.assert_has_calls(
|
||||||
mock_table.return_value.add_row.assert_has_calls([
|
[call("Org 1", "org-123"), call("Org 2", "org-456")]
|
||||||
call("Org 1", "org-123"),
|
)
|
||||||
call("Org 2", "org-456")
|
|
||||||
])
|
|
||||||
|
|
||||||
@patch('crewai.cli.organization.main.console')
|
@patch("crewai.cli.organization.main.console")
|
||||||
def test_list_organizations_empty(self, mock_console):
|
def test_list_organizations_empty(self, mock_console):
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.raise_for_status = MagicMock()
|
mock_response.raise_for_status = MagicMock()
|
||||||
@@ -110,33 +108,32 @@ class TestOrganizationCommand(unittest.TestCase):
|
|||||||
|
|
||||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||||
mock_console.print.assert_called_once_with(
|
mock_console.print.assert_called_once_with(
|
||||||
"You don't belong to any organizations yet.",
|
"You don't belong to any organizations yet.", style="yellow"
|
||||||
style="yellow"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch('crewai.cli.organization.main.console')
|
@patch("crewai.cli.organization.main.console")
|
||||||
def test_list_organizations_api_error(self, mock_console):
|
def test_list_organizations_api_error(self, mock_console):
|
||||||
self.org_command.plus_api_client = MagicMock()
|
self.org_command.plus_api_client = MagicMock()
|
||||||
self.org_command.plus_api_client.get_organizations.side_effect = requests.exceptions.RequestException("API Error")
|
self.org_command.plus_api_client.get_organizations.side_effect = (
|
||||||
|
requests.exceptions.RequestException("API Error")
|
||||||
|
)
|
||||||
|
|
||||||
with pytest.raises(SystemExit):
|
with pytest.raises(SystemExit):
|
||||||
self.org_command.list()
|
self.org_command.list()
|
||||||
|
|
||||||
|
|
||||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||||
mock_console.print.assert_called_once_with(
|
mock_console.print.assert_called_once_with(
|
||||||
"Failed to retrieve organization list: API Error",
|
"Failed to retrieve organization list: API Error", style="bold red"
|
||||||
style="bold red"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch('crewai.cli.organization.main.console')
|
@patch("crewai.cli.organization.main.console")
|
||||||
@patch('crewai.cli.organization.main.Settings')
|
@patch("crewai.cli.organization.main.Settings")
|
||||||
def test_switch_organization_success(self, mock_settings_class, mock_console):
|
def test_switch_organization_success(self, mock_settings_class, mock_console):
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.raise_for_status = MagicMock()
|
mock_response.raise_for_status = MagicMock()
|
||||||
mock_response.json.return_value = [
|
mock_response.json.return_value = [
|
||||||
{"name": "Org 1", "uuid": "org-123"},
|
{"name": "Org 1", "uuid": "org-123"},
|
||||||
{"name": "Test Org", "uuid": "test-id"}
|
{"name": "Test Org", "uuid": "test-id"},
|
||||||
]
|
]
|
||||||
self.org_command.plus_api_client = MagicMock()
|
self.org_command.plus_api_client = MagicMock()
|
||||||
self.org_command.plus_api_client.get_organizations.return_value = mock_response
|
self.org_command.plus_api_client.get_organizations.return_value = mock_response
|
||||||
@@ -151,17 +148,16 @@ class TestOrganizationCommand(unittest.TestCase):
|
|||||||
assert mock_settings_instance.org_name == "Test Org"
|
assert mock_settings_instance.org_name == "Test Org"
|
||||||
assert mock_settings_instance.org_uuid == "test-id"
|
assert mock_settings_instance.org_uuid == "test-id"
|
||||||
mock_console.print.assert_called_once_with(
|
mock_console.print.assert_called_once_with(
|
||||||
"Successfully switched to Test Org (test-id)",
|
"Successfully switched to Test Org (test-id)", style="bold green"
|
||||||
style="bold green"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch('crewai.cli.organization.main.console')
|
@patch("crewai.cli.organization.main.console")
|
||||||
def test_switch_organization_not_found(self, mock_console):
|
def test_switch_organization_not_found(self, mock_console):
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_response.raise_for_status = MagicMock()
|
mock_response.raise_for_status = MagicMock()
|
||||||
mock_response.json.return_value = [
|
mock_response.json.return_value = [
|
||||||
{"name": "Org 1", "uuid": "org-123"},
|
{"name": "Org 1", "uuid": "org-123"},
|
||||||
{"name": "Org 2", "uuid": "org-456"}
|
{"name": "Org 2", "uuid": "org-456"},
|
||||||
]
|
]
|
||||||
self.org_command.plus_api_client = MagicMock()
|
self.org_command.plus_api_client = MagicMock()
|
||||||
self.org_command.plus_api_client.get_organizations.return_value = mock_response
|
self.org_command.plus_api_client.get_organizations.return_value = mock_response
|
||||||
@@ -170,12 +166,11 @@ class TestOrganizationCommand(unittest.TestCase):
|
|||||||
|
|
||||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||||
mock_console.print.assert_called_once_with(
|
mock_console.print.assert_called_once_with(
|
||||||
"Organization with id 'non-existent-id' not found.",
|
"Organization with id 'non-existent-id' not found.", style="bold red"
|
||||||
style="bold red"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch('crewai.cli.organization.main.console')
|
@patch("crewai.cli.organization.main.console")
|
||||||
@patch('crewai.cli.organization.main.Settings')
|
@patch("crewai.cli.organization.main.Settings")
|
||||||
def test_current_organization_with_org(self, mock_settings_class, mock_console):
|
def test_current_organization_with_org(self, mock_settings_class, mock_console):
|
||||||
mock_settings_instance = MagicMock()
|
mock_settings_instance = MagicMock()
|
||||||
mock_settings_instance.org_name = "Test Org"
|
mock_settings_instance.org_name = "Test Org"
|
||||||
@@ -186,12 +181,11 @@ class TestOrganizationCommand(unittest.TestCase):
|
|||||||
|
|
||||||
self.org_command.plus_api_client.get_organizations.assert_not_called()
|
self.org_command.plus_api_client.get_organizations.assert_not_called()
|
||||||
mock_console.print.assert_called_once_with(
|
mock_console.print.assert_called_once_with(
|
||||||
"Currently logged in to organization Test Org (test-id)",
|
"Currently logged in to organization Test Org (test-id)", style="bold green"
|
||||||
style="bold green"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch('crewai.cli.organization.main.console')
|
@patch("crewai.cli.organization.main.console")
|
||||||
@patch('crewai.cli.organization.main.Settings')
|
@patch("crewai.cli.organization.main.Settings")
|
||||||
def test_current_organization_without_org(self, mock_settings_class, mock_console):
|
def test_current_organization_without_org(self, mock_settings_class, mock_console):
|
||||||
mock_settings_instance = MagicMock()
|
mock_settings_instance = MagicMock()
|
||||||
mock_settings_instance.org_uuid = None
|
mock_settings_instance.org_uuid = None
|
||||||
@@ -201,16 +195,14 @@ class TestOrganizationCommand(unittest.TestCase):
|
|||||||
|
|
||||||
assert mock_console.print.call_count == 3
|
assert mock_console.print.call_count == 3
|
||||||
mock_console.print.assert_any_call(
|
mock_console.print.assert_any_call(
|
||||||
"You're not currently logged in to any organization.",
|
"You're not currently logged in to any organization.", style="yellow"
|
||||||
style="yellow"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch('crewai.cli.organization.main.console')
|
@patch("crewai.cli.organization.main.console")
|
||||||
def test_list_organizations_unauthorized(self, mock_console):
|
def test_list_organizations_unauthorized(self, mock_console):
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_http_error = requests.exceptions.HTTPError(
|
mock_http_error = requests.exceptions.HTTPError(
|
||||||
"401 Client Error: Unauthorized",
|
"401 Client Error: Unauthorized", response=MagicMock(status_code=401)
|
||||||
response=MagicMock(status_code=401)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_response.raise_for_status.side_effect = mock_http_error
|
mock_response.raise_for_status.side_effect = mock_http_error
|
||||||
@@ -221,15 +213,14 @@ class TestOrganizationCommand(unittest.TestCase):
|
|||||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||||
mock_console.print.assert_called_once_with(
|
mock_console.print.assert_called_once_with(
|
||||||
"You are not logged in to any organization. Use 'crewai login' to login.",
|
"You are not logged in to any organization. Use 'crewai login' to login.",
|
||||||
style="bold red"
|
style="bold red",
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch('crewai.cli.organization.main.console')
|
@patch("crewai.cli.organization.main.console")
|
||||||
def test_switch_organization_unauthorized(self, mock_console):
|
def test_switch_organization_unauthorized(self, mock_console):
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_http_error = requests.exceptions.HTTPError(
|
mock_http_error = requests.exceptions.HTTPError(
|
||||||
"401 Client Error: Unauthorized",
|
"401 Client Error: Unauthorized", response=MagicMock(status_code=401)
|
||||||
response=MagicMock(status_code=401)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
mock_response.raise_for_status.side_effect = mock_http_error
|
mock_response.raise_for_status.side_effect = mock_http_error
|
||||||
@@ -240,5 +231,5 @@ class TestOrganizationCommand(unittest.TestCase):
|
|||||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||||
mock_console.print.assert_called_once_with(
|
mock_console.print.assert_called_once_with(
|
||||||
"You are not logged in to any organization. Use 'crewai login' to login.",
|
"You are not logged in to any organization. Use 'crewai login' to login.",
|
||||||
style="bold red"
|
style="bold red",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
import os
|
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import MagicMock, patch, ANY
|
from unittest.mock import MagicMock, patch, ANY
|
||||||
|
|
||||||
from crewai.cli.plus_api import PlusAPI
|
from crewai.cli.plus_api import PlusAPI
|
||||||
|
from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||||
|
|
||||||
|
|
||||||
class TestPlusAPI(unittest.TestCase):
|
class TestPlusAPI(unittest.TestCase):
|
||||||
@@ -30,29 +30,41 @@ class TestPlusAPI(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(response, mock_response)
|
self.assertEqual(response, mock_response)
|
||||||
|
|
||||||
def assert_request_with_org_id(self, mock_make_request, method: str, endpoint: str, **kwargs):
|
def assert_request_with_org_id(
|
||||||
|
self, mock_make_request, method: str, endpoint: str, **kwargs
|
||||||
|
):
|
||||||
mock_make_request.assert_called_once_with(
|
mock_make_request.assert_called_once_with(
|
||||||
method, f"https://app.crewai.com{endpoint}", headers={'Authorization': ANY, 'Content-Type': ANY, 'User-Agent': ANY, 'X-Crewai-Version': ANY, 'X-Crewai-Organization-Id': self.org_uuid}, **kwargs
|
method,
|
||||||
|
f"{DEFAULT_CREWAI_ENTERPRISE_URL}{endpoint}",
|
||||||
|
headers={
|
||||||
|
"Authorization": ANY,
|
||||||
|
"Content-Type": ANY,
|
||||||
|
"User-Agent": ANY,
|
||||||
|
"X-Crewai-Version": ANY,
|
||||||
|
"X-Crewai-Organization-Id": self.org_uuid,
|
||||||
|
},
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch("crewai.cli.plus_api.Settings")
|
@patch("crewai.cli.plus_api.Settings")
|
||||||
@patch("requests.Session.request")
|
@patch("requests.Session.request")
|
||||||
def test_login_to_tool_repository_with_org_uuid(self, mock_make_request, mock_settings_class):
|
def test_login_to_tool_repository_with_org_uuid(
|
||||||
|
self, mock_make_request, mock_settings_class
|
||||||
|
):
|
||||||
mock_settings = MagicMock()
|
mock_settings = MagicMock()
|
||||||
mock_settings.org_uuid = self.org_uuid
|
mock_settings.org_uuid = self.org_uuid
|
||||||
|
mock_settings.enterprise_base_url = DEFAULT_CREWAI_ENTERPRISE_URL
|
||||||
mock_settings_class.return_value = mock_settings
|
mock_settings_class.return_value = mock_settings
|
||||||
# re-initialize Client
|
# re-initialize Client
|
||||||
self.api = PlusAPI(self.api_key)
|
self.api = PlusAPI(self.api_key)
|
||||||
|
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_make_request.return_value = mock_response
|
mock_make_request.return_value = mock_response
|
||||||
|
|
||||||
response = self.api.login_to_tool_repository()
|
response = self.api.login_to_tool_repository()
|
||||||
|
|
||||||
self.assert_request_with_org_id(
|
self.assert_request_with_org_id(
|
||||||
mock_make_request,
|
mock_make_request, "POST", "/crewai_plus/api/v1/tools/login"
|
||||||
'POST',
|
|
||||||
'/crewai_plus/api/v1/tools/login'
|
|
||||||
)
|
)
|
||||||
self.assertEqual(response, mock_response)
|
self.assertEqual(response, mock_response)
|
||||||
|
|
||||||
@@ -66,28 +78,27 @@ class TestPlusAPI(unittest.TestCase):
|
|||||||
"GET", "/crewai_plus/api/v1/agents/test_agent_handle"
|
"GET", "/crewai_plus/api/v1/agents/test_agent_handle"
|
||||||
)
|
)
|
||||||
self.assertEqual(response, mock_response)
|
self.assertEqual(response, mock_response)
|
||||||
|
|
||||||
@patch("crewai.cli.plus_api.Settings")
|
@patch("crewai.cli.plus_api.Settings")
|
||||||
@patch("requests.Session.request")
|
@patch("requests.Session.request")
|
||||||
def test_get_agent_with_org_uuid(self, mock_make_request, mock_settings_class):
|
def test_get_agent_with_org_uuid(self, mock_make_request, mock_settings_class):
|
||||||
mock_settings = MagicMock()
|
mock_settings = MagicMock()
|
||||||
mock_settings.org_uuid = self.org_uuid
|
mock_settings.org_uuid = self.org_uuid
|
||||||
|
mock_settings.enterprise_base_url = DEFAULT_CREWAI_ENTERPRISE_URL
|
||||||
mock_settings_class.return_value = mock_settings
|
mock_settings_class.return_value = mock_settings
|
||||||
# re-initialize Client
|
# re-initialize Client
|
||||||
self.api = PlusAPI(self.api_key)
|
self.api = PlusAPI(self.api_key)
|
||||||
|
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_make_request.return_value = mock_response
|
mock_make_request.return_value = mock_response
|
||||||
|
|
||||||
response = self.api.get_agent("test_agent_handle")
|
response = self.api.get_agent("test_agent_handle")
|
||||||
|
|
||||||
self.assert_request_with_org_id(
|
self.assert_request_with_org_id(
|
||||||
mock_make_request,
|
mock_make_request, "GET", "/crewai_plus/api/v1/agents/test_agent_handle"
|
||||||
"GET",
|
|
||||||
"/crewai_plus/api/v1/agents/test_agent_handle"
|
|
||||||
)
|
)
|
||||||
self.assertEqual(response, mock_response)
|
self.assertEqual(response, mock_response)
|
||||||
|
|
||||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||||
def test_get_tool(self, mock_make_request):
|
def test_get_tool(self, mock_make_request):
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
@@ -98,12 +109,13 @@ class TestPlusAPI(unittest.TestCase):
|
|||||||
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
|
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
|
||||||
)
|
)
|
||||||
self.assertEqual(response, mock_response)
|
self.assertEqual(response, mock_response)
|
||||||
|
|
||||||
@patch("crewai.cli.plus_api.Settings")
|
@patch("crewai.cli.plus_api.Settings")
|
||||||
@patch("requests.Session.request")
|
@patch("requests.Session.request")
|
||||||
def test_get_tool_with_org_uuid(self, mock_make_request, mock_settings_class):
|
def test_get_tool_with_org_uuid(self, mock_make_request, mock_settings_class):
|
||||||
mock_settings = MagicMock()
|
mock_settings = MagicMock()
|
||||||
mock_settings.org_uuid = self.org_uuid
|
mock_settings.org_uuid = self.org_uuid
|
||||||
|
mock_settings.enterprise_base_url = DEFAULT_CREWAI_ENTERPRISE_URL
|
||||||
mock_settings_class.return_value = mock_settings
|
mock_settings_class.return_value = mock_settings
|
||||||
# re-initialize Client
|
# re-initialize Client
|
||||||
self.api = PlusAPI(self.api_key)
|
self.api = PlusAPI(self.api_key)
|
||||||
@@ -115,9 +127,7 @@ class TestPlusAPI(unittest.TestCase):
|
|||||||
response = self.api.get_tool("test_tool_handle")
|
response = self.api.get_tool("test_tool_handle")
|
||||||
|
|
||||||
self.assert_request_with_org_id(
|
self.assert_request_with_org_id(
|
||||||
mock_make_request,
|
mock_make_request, "GET", "/crewai_plus/api/v1/tools/test_tool_handle"
|
||||||
"GET",
|
|
||||||
"/crewai_plus/api/v1/tools/test_tool_handle"
|
|
||||||
)
|
)
|
||||||
self.assertEqual(response, mock_response)
|
self.assertEqual(response, mock_response)
|
||||||
|
|
||||||
@@ -147,12 +157,13 @@ class TestPlusAPI(unittest.TestCase):
|
|||||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||||
)
|
)
|
||||||
self.assertEqual(response, mock_response)
|
self.assertEqual(response, mock_response)
|
||||||
|
|
||||||
@patch("crewai.cli.plus_api.Settings")
|
@patch("crewai.cli.plus_api.Settings")
|
||||||
@patch("requests.Session.request")
|
@patch("requests.Session.request")
|
||||||
def test_publish_tool_with_org_uuid(self, mock_make_request, mock_settings_class):
|
def test_publish_tool_with_org_uuid(self, mock_make_request, mock_settings_class):
|
||||||
mock_settings = MagicMock()
|
mock_settings = MagicMock()
|
||||||
mock_settings.org_uuid = self.org_uuid
|
mock_settings.org_uuid = self.org_uuid
|
||||||
|
mock_settings.enterprise_base_url = DEFAULT_CREWAI_ENTERPRISE_URL
|
||||||
mock_settings_class.return_value = mock_settings
|
mock_settings_class.return_value = mock_settings
|
||||||
# re-initialize Client
|
# re-initialize Client
|
||||||
self.api = PlusAPI(self.api_key)
|
self.api = PlusAPI(self.api_key)
|
||||||
@@ -160,7 +171,7 @@ class TestPlusAPI(unittest.TestCase):
|
|||||||
# Set up mock response
|
# Set up mock response
|
||||||
mock_response = MagicMock()
|
mock_response = MagicMock()
|
||||||
mock_make_request.return_value = mock_response
|
mock_make_request.return_value = mock_response
|
||||||
|
|
||||||
handle = "test_tool_handle"
|
handle = "test_tool_handle"
|
||||||
public = True
|
public = True
|
||||||
version = "1.0.0"
|
version = "1.0.0"
|
||||||
@@ -180,12 +191,9 @@ class TestPlusAPI(unittest.TestCase):
|
|||||||
"description": description,
|
"description": description,
|
||||||
"available_exports": None,
|
"available_exports": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
self.assert_request_with_org_id(
|
self.assert_request_with_org_id(
|
||||||
mock_make_request,
|
mock_make_request, "POST", "/crewai_plus/api/v1/tools", json=expected_params
|
||||||
"POST",
|
|
||||||
"/crewai_plus/api/v1/tools",
|
|
||||||
json=expected_params
|
|
||||||
)
|
)
|
||||||
self.assertEqual(response, mock_response)
|
self.assertEqual(response, mock_response)
|
||||||
|
|
||||||
@@ -311,8 +319,11 @@ class TestPlusAPI(unittest.TestCase):
|
|||||||
"POST", "/crewai_plus/api/v1/crews", json=payload
|
"POST", "/crewai_plus/api/v1/crews", json=payload
|
||||||
)
|
)
|
||||||
|
|
||||||
@patch.dict(os.environ, {"CREWAI_BASE_URL": "https://custom-url.com/api"})
|
@patch("crewai.cli.plus_api.Settings")
|
||||||
def test_custom_base_url(self):
|
def test_custom_base_url(self, mock_settings_class):
|
||||||
|
mock_settings = MagicMock()
|
||||||
|
mock_settings.enterprise_base_url = "https://custom-url.com/api"
|
||||||
|
mock_settings_class.return_value = mock_settings
|
||||||
custom_api = PlusAPI("test_key")
|
custom_api = PlusAPI("test_key")
|
||||||
self.assertEqual(
|
self.assertEqual(
|
||||||
custom_api.base_url,
|
custom_api.base_url,
|
||||||
|
|||||||
91
tests/cli/test_settings_command.py
Normal file
91
tests/cli/test_settings_command.py
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
import tempfile
|
||||||
|
import unittest
|
||||||
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch, MagicMock, call
|
||||||
|
|
||||||
|
from crewai.cli.settings.main import SettingsCommand
|
||||||
|
from crewai.cli.config import (
|
||||||
|
Settings,
|
||||||
|
USER_SETTINGS_KEYS,
|
||||||
|
CLI_SETTINGS_KEYS,
|
||||||
|
DEFAULT_CLI_SETTINGS,
|
||||||
|
HIDDEN_SETTINGS_KEYS,
|
||||||
|
READONLY_SETTINGS_KEYS,
|
||||||
|
)
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
|
||||||
|
class TestSettingsCommand(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
self.test_dir = Path(tempfile.mkdtemp())
|
||||||
|
self.config_path = self.test_dir / "settings.json"
|
||||||
|
self.settings = Settings(config_path=self.config_path)
|
||||||
|
self.settings_command = SettingsCommand(
|
||||||
|
settings_kwargs={"config_path": self.config_path}
|
||||||
|
)
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
shutil.rmtree(self.test_dir)
|
||||||
|
|
||||||
|
@patch("crewai.cli.settings.main.console")
|
||||||
|
@patch("crewai.cli.settings.main.Table")
|
||||||
|
def test_list_settings(self, mock_table_class, mock_console):
|
||||||
|
mock_table_instance = MagicMock()
|
||||||
|
mock_table_class.return_value = mock_table_instance
|
||||||
|
|
||||||
|
self.settings_command.list()
|
||||||
|
|
||||||
|
# Tests that the table is created skipping hidden settings
|
||||||
|
mock_table_instance.add_row.assert_has_calls(
|
||||||
|
[
|
||||||
|
call(
|
||||||
|
field_name,
|
||||||
|
getattr(self.settings, field_name) or "Not set",
|
||||||
|
field_info.description,
|
||||||
|
)
|
||||||
|
for field_name, field_info in Settings.model_fields.items()
|
||||||
|
if field_name not in HIDDEN_SETTINGS_KEYS
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Tests that the table is printed
|
||||||
|
mock_console.print.assert_called_once_with(mock_table_instance)
|
||||||
|
|
||||||
|
def test_set_valid_keys(self):
|
||||||
|
valid_keys = Settings.model_fields.keys() - (
|
||||||
|
READONLY_SETTINGS_KEYS + HIDDEN_SETTINGS_KEYS
|
||||||
|
)
|
||||||
|
for key in valid_keys:
|
||||||
|
test_value = f"some_value_for_{key}"
|
||||||
|
self.settings_command.set(key, test_value)
|
||||||
|
self.assertEqual(getattr(self.settings_command.settings, key), test_value)
|
||||||
|
|
||||||
|
def test_set_invalid_key(self):
|
||||||
|
with self.assertRaises(SystemExit):
|
||||||
|
self.settings_command.set("invalid_key", "value")
|
||||||
|
|
||||||
|
def test_set_readonly_keys(self):
|
||||||
|
for key in READONLY_SETTINGS_KEYS:
|
||||||
|
with self.assertRaises(SystemExit):
|
||||||
|
self.settings_command.set(key, "some_readonly_key_value")
|
||||||
|
|
||||||
|
def test_set_hidden_keys(self):
|
||||||
|
for key in HIDDEN_SETTINGS_KEYS:
|
||||||
|
with self.assertRaises(SystemExit):
|
||||||
|
self.settings_command.set(key, "some_hidden_key_value")
|
||||||
|
|
||||||
|
def test_reset_all_settings(self):
|
||||||
|
for key in USER_SETTINGS_KEYS + CLI_SETTINGS_KEYS:
|
||||||
|
setattr(self.settings_command.settings, key, f"custom_value_for_{key}")
|
||||||
|
self.settings_command.settings.dump()
|
||||||
|
|
||||||
|
self.settings_command.reset_all_settings()
|
||||||
|
|
||||||
|
print(USER_SETTINGS_KEYS)
|
||||||
|
for key in USER_SETTINGS_KEYS:
|
||||||
|
self.assertEqual(getattr(self.settings_command.settings, key), None)
|
||||||
|
|
||||||
|
for key in CLI_SETTINGS_KEYS:
|
||||||
|
self.assertEqual(
|
||||||
|
getattr(self.settings_command.settings, key), DEFAULT_CLI_SETTINGS[key]
|
||||||
|
)
|
||||||
Reference in New Issue
Block a user