diff --git a/lib/cli/tests/__init__.py b/lib/cli/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lib/cli/tests/authentication/__init__.py b/lib/cli/tests/authentication/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lib/cli/tests/authentication/providers/__init__.py b/lib/cli/tests/authentication/providers/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lib/cli/tests/authentication/providers/test_auth0.py b/lib/cli/tests/authentication/providers/test_auth0.py new file mode 100644 index 000000000..c91acf225 --- /dev/null +++ b/lib/cli/tests/authentication/providers/test_auth0.py @@ -0,0 +1,91 @@ +import pytest +from crewai_cli.authentication.main import Oauth2Settings +from crewai_cli.authentication.providers.auth0 import Auth0Provider + + + +class TestAuth0Provider: + + @pytest.fixture(autouse=True) + def setup_method(self): + self.valid_settings = Oauth2Settings( + provider="auth0", + domain="test-domain.auth0.com", + client_id="test-client-id", + audience="test-audience" + ) + self.provider = Auth0Provider(self.valid_settings) + + def test_initialization_with_valid_settings(self): + provider = Auth0Provider(self.valid_settings) + assert provider.settings == self.valid_settings + assert provider.settings.provider == "auth0" + assert provider.settings.domain == "test-domain.auth0.com" + assert provider.settings.client_id == "test-client-id" + assert provider.settings.audience == "test-audience" + + def test_get_authorize_url(self): + expected_url = "https://test-domain.auth0.com/oauth/device/code" + assert self.provider.get_authorize_url() == expected_url + + def test_get_authorize_url_with_different_domain(self): + settings = Oauth2Settings( + provider="auth0", + domain="my-company.auth0.com", + client_id="test-client", + audience="test-audience" + ) + provider = Auth0Provider(settings) + expected_url = "https://my-company.auth0.com/oauth/device/code" + assert provider.get_authorize_url() == expected_url + + def test_get_token_url(self): + expected_url = "https://test-domain.auth0.com/oauth/token" + assert self.provider.get_token_url() == expected_url + + def test_get_token_url_with_different_domain(self): + settings = Oauth2Settings( + provider="auth0", + domain="another-domain.auth0.com", + client_id="test-client", + audience="test-audience" + ) + provider = Auth0Provider(settings) + expected_url = "https://another-domain.auth0.com/oauth/token" + assert provider.get_token_url() == expected_url + + def test_get_jwks_url(self): + expected_url = "https://test-domain.auth0.com/.well-known/jwks.json" + assert self.provider.get_jwks_url() == expected_url + + def test_get_jwks_url_with_different_domain(self): + settings = Oauth2Settings( + provider="auth0", + domain="dev.auth0.com", + client_id="test-client", + audience="test-audience" + ) + provider = Auth0Provider(settings) + expected_url = "https://dev.auth0.com/.well-known/jwks.json" + assert provider.get_jwks_url() == expected_url + + def test_get_issuer(self): + expected_issuer = "https://test-domain.auth0.com/" + assert self.provider.get_issuer() == expected_issuer + + def test_get_issuer_with_different_domain(self): + settings = Oauth2Settings( + provider="auth0", + domain="prod.auth0.com", + client_id="test-client", + audience="test-audience" + ) + provider = Auth0Provider(settings) + expected_issuer = "https://prod.auth0.com/" + assert provider.get_issuer() == expected_issuer + + def test_get_audience(self): + assert self.provider.get_audience() == "test-audience" + + def test_get_client_id(self): + assert self.provider.get_client_id() == "test-client-id" diff --git a/lib/cli/tests/authentication/providers/test_entra_id.py b/lib/cli/tests/authentication/providers/test_entra_id.py new file mode 100644 index 000000000..31ae3d018 --- /dev/null +++ b/lib/cli/tests/authentication/providers/test_entra_id.py @@ -0,0 +1,141 @@ +import pytest + +from crewai_cli.authentication.main import Oauth2Settings +from crewai_cli.authentication.providers.entra_id import EntraIdProvider + + +class TestEntraIdProvider: + @pytest.fixture(autouse=True) + def setup_method(self): + self.valid_settings = Oauth2Settings( + provider="entra_id", + domain="tenant-id-abcdef123456", + client_id="test-client-id", + audience="test-audience", + extra={ + "scope": "openid profile email api://crewai-cli-dev/read" + } + ) + self.provider = EntraIdProvider(self.valid_settings) + + def test_initialization_with_valid_settings(self): + provider = EntraIdProvider(self.valid_settings) + assert provider.settings == self.valid_settings + assert provider.settings.provider == "entra_id" + assert provider.settings.domain == "tenant-id-abcdef123456" + assert provider.settings.client_id == "test-client-id" + assert provider.settings.audience == "test-audience" + + def test_get_authorize_url(self): + expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/devicecode" + assert self.provider.get_authorize_url() == expected_url + + def test_get_authorize_url_with_different_domain(self): + # For EntraID, the domain is the tenant ID. + settings = Oauth2Settings( + provider="entra_id", + domain="my-company.entra.id", + client_id="test-client", + audience="test-audience", + ) + provider = EntraIdProvider(settings) + expected_url = "https://login.microsoftonline.com/my-company.entra.id/oauth2/v2.0/devicecode" + assert provider.get_authorize_url() == expected_url + + def test_get_token_url(self): + expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/token" + assert self.provider.get_token_url() == expected_url + + def test_get_token_url_with_different_domain(self): + # For EntraID, the domain is the tenant ID. + settings = Oauth2Settings( + provider="entra_id", + domain="another-domain.entra.id", + client_id="test-client", + audience="test-audience", + ) + provider = EntraIdProvider(settings) + expected_url = "https://login.microsoftonline.com/another-domain.entra.id/oauth2/v2.0/token" + assert provider.get_token_url() == expected_url + + def test_get_jwks_url(self): + expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/discovery/v2.0/keys" + assert self.provider.get_jwks_url() == expected_url + + def test_get_jwks_url_with_different_domain(self): + # For EntraID, the domain is the tenant ID. + settings = Oauth2Settings( + provider="entra_id", + domain="dev.entra.id", + client_id="test-client", + audience="test-audience", + ) + provider = EntraIdProvider(settings) + expected_url = "https://login.microsoftonline.com/dev.entra.id/discovery/v2.0/keys" + assert provider.get_jwks_url() == expected_url + + def test_get_issuer(self): + expected_issuer = "https://login.microsoftonline.com/tenant-id-abcdef123456/v2.0" + assert self.provider.get_issuer() == expected_issuer + + def test_get_issuer_with_different_domain(self): + # For EntraID, the domain is the tenant ID. + settings = Oauth2Settings( + provider="entra_id", + domain="other-tenant-id-xpto", + client_id="test-client", + audience="test-audience", + ) + provider = EntraIdProvider(settings) + expected_issuer = "https://login.microsoftonline.com/other-tenant-id-xpto/v2.0" + assert provider.get_issuer() == expected_issuer + + def test_get_audience(self): + assert self.provider.get_audience() == "test-audience" + + def test_get_audience_assertion_error_when_none(self): + settings = Oauth2Settings( + provider="entra_id", + domain="test-tenant-id", + client_id="test-client-id", + audience=None, + ) + provider = EntraIdProvider(settings) + + with pytest.raises(ValueError, match="Audience is required"): + provider.get_audience() + + def test_get_client_id(self): + assert self.provider.get_client_id() == "test-client-id" + + def test_get_required_fields(self): + assert set(self.provider.get_required_fields()) == set(["scope"]) + + def test_get_oauth_scopes(self): + settings = Oauth2Settings( + provider="entra_id", + domain="tenant-id-abcdef123456", + client_id="test-client-id", + audience="test-audience", + extra={ + "scope": "api://crewai-cli-dev/read" + } + ) + provider = EntraIdProvider(settings) + assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read"] + + def test_get_oauth_scopes_with_multiple_custom_scopes(self): + settings = Oauth2Settings( + provider="entra_id", + domain="tenant-id-abcdef123456", + client_id="test-client-id", + audience="test-audience", + extra={ + "scope": "api://crewai-cli-dev/read api://crewai-cli-dev/write custom-scope1 custom-scope2" + } + ) + provider = EntraIdProvider(settings) + assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read", "api://crewai-cli-dev/write", "custom-scope1", "custom-scope2"] + + def test_base_url(self): + assert self.provider._base_url() == "https://login.microsoftonline.com/tenant-id-abcdef123456" \ No newline at end of file diff --git a/lib/cli/tests/authentication/providers/test_keycloak.py b/lib/cli/tests/authentication/providers/test_keycloak.py new file mode 100644 index 000000000..e9637da6f --- /dev/null +++ b/lib/cli/tests/authentication/providers/test_keycloak.py @@ -0,0 +1,138 @@ +import pytest + +from crewai_cli.authentication.main import Oauth2Settings +from crewai_cli.authentication.providers.keycloak import KeycloakProvider + + +class TestKeycloakProvider: + @pytest.fixture(autouse=True) + def setup_method(self): + self.valid_settings = Oauth2Settings( + provider="keycloak", + domain="keycloak.example.com", + client_id="test-client-id", + audience="test-audience", + extra={ + "realm": "test-realm" + } + ) + self.provider = KeycloakProvider(self.valid_settings) + + def test_initialization_with_valid_settings(self): + provider = KeycloakProvider(self.valid_settings) + assert provider.settings == self.valid_settings + assert provider.settings.provider == "keycloak" + assert provider.settings.domain == "keycloak.example.com" + assert provider.settings.client_id == "test-client-id" + assert provider.settings.audience == "test-audience" + assert provider.settings.extra.get("realm") == "test-realm" + + def test_get_authorize_url(self): + expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/auth/device" + assert self.provider.get_authorize_url() == expected_url + + def test_get_authorize_url_with_different_domain(self): + settings = Oauth2Settings( + provider="keycloak", + domain="auth.company.com", + client_id="test-client", + audience="test-audience", + extra={ + "realm": "my-realm" + } + ) + provider = KeycloakProvider(settings) + expected_url = "https://auth.company.com/realms/my-realm/protocol/openid-connect/auth/device" + assert provider.get_authorize_url() == expected_url + + def test_get_token_url(self): + expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/token" + assert self.provider.get_token_url() == expected_url + + def test_get_token_url_with_different_domain(self): + settings = Oauth2Settings( + provider="keycloak", + domain="sso.enterprise.com", + client_id="test-client", + audience="test-audience", + extra={ + "realm": "enterprise-realm" + } + ) + provider = KeycloakProvider(settings) + expected_url = "https://sso.enterprise.com/realms/enterprise-realm/protocol/openid-connect/token" + assert provider.get_token_url() == expected_url + + def test_get_jwks_url(self): + expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/certs" + assert self.provider.get_jwks_url() == expected_url + + def test_get_jwks_url_with_different_domain(self): + settings = Oauth2Settings( + provider="keycloak", + domain="identity.org", + client_id="test-client", + audience="test-audience", + extra={ + "realm": "org-realm" + } + ) + provider = KeycloakProvider(settings) + expected_url = "https://identity.org/realms/org-realm/protocol/openid-connect/certs" + assert provider.get_jwks_url() == expected_url + + def test_get_issuer(self): + expected_issuer = "https://keycloak.example.com/realms/test-realm" + assert self.provider.get_issuer() == expected_issuer + + def test_get_issuer_with_different_domain(self): + settings = Oauth2Settings( + provider="keycloak", + domain="login.myapp.io", + client_id="test-client", + audience="test-audience", + extra={ + "realm": "app-realm" + } + ) + provider = KeycloakProvider(settings) + expected_issuer = "https://login.myapp.io/realms/app-realm" + assert provider.get_issuer() == expected_issuer + + def test_get_audience(self): + assert self.provider.get_audience() == "test-audience" + + def test_get_client_id(self): + assert self.provider.get_client_id() == "test-client-id" + + def test_get_required_fields(self): + assert self.provider.get_required_fields() == ["realm"] + + def test_oauth2_base_url(self): + assert self.provider._oauth2_base_url() == "https://keycloak.example.com" + + def test_oauth2_base_url_strips_https_prefix(self): + settings = Oauth2Settings( + provider="keycloak", + domain="https://keycloak.example.com", + client_id="test-client-id", + audience="test-audience", + extra={ + "realm": "test-realm" + } + ) + provider = KeycloakProvider(settings) + assert provider._oauth2_base_url() == "https://keycloak.example.com" + + def test_oauth2_base_url_strips_http_prefix(self): + settings = Oauth2Settings( + provider="keycloak", + domain="http://keycloak.example.com", + client_id="test-client-id", + audience="test-audience", + extra={ + "realm": "test-realm" + } + ) + provider = KeycloakProvider(settings) + assert provider._oauth2_base_url() == "https://keycloak.example.com" diff --git a/lib/cli/tests/authentication/providers/test_okta.py b/lib/cli/tests/authentication/providers/test_okta.py new file mode 100644 index 000000000..42d292508 --- /dev/null +++ b/lib/cli/tests/authentication/providers/test_okta.py @@ -0,0 +1,257 @@ +import pytest + +from crewai_cli.authentication.main import Oauth2Settings +from crewai_cli.authentication.providers.okta import OktaProvider + + +class TestOktaProvider: + @pytest.fixture(autouse=True) + def setup_method(self): + self.valid_settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience="test-audience", + ) + self.provider = OktaProvider(self.valid_settings) + + def test_initialization_with_valid_settings(self): + provider = OktaProvider(self.valid_settings) + assert provider.settings == self.valid_settings + assert provider.settings.provider == "okta" + assert provider.settings.domain == "test-domain.okta.com" + assert provider.settings.client_id == "test-client-id" + assert provider.settings.audience == "test-audience" + + def test_get_authorize_url(self): + expected_url = "https://test-domain.okta.com/oauth2/default/v1/device/authorize" + assert self.provider.get_authorize_url() == expected_url + + def test_get_authorize_url_with_different_domain(self): + settings = Oauth2Settings( + provider="okta", + domain="my-company.okta.com", + client_id="test-client", + audience="test-audience", + ) + provider = OktaProvider(settings) + expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize" + assert provider.get_authorize_url() == expected_url + + def test_get_authorize_url_with_custom_authorization_server_name(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": False, + "authorization_server_name": "my_auth_server_xxxAAA777" + } + ) + provider = OktaProvider(settings) + expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/device/authorize" + assert provider.get_authorize_url() == expected_url + + def test_get_authorize_url_when_using_org_auth_server(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": True, + "authorization_server_name": None + } + ) + provider = OktaProvider(settings) + expected_url = "https://test-domain.okta.com/oauth2/v1/device/authorize" + assert provider.get_authorize_url() == expected_url + + def test_get_token_url(self): + expected_url = "https://test-domain.okta.com/oauth2/default/v1/token" + assert self.provider.get_token_url() == expected_url + + def test_get_token_url_with_different_domain(self): + settings = Oauth2Settings( + provider="okta", + domain="another-domain.okta.com", + client_id="test-client", + audience="test-audience", + ) + provider = OktaProvider(settings) + expected_url = "https://another-domain.okta.com/oauth2/default/v1/token" + assert provider.get_token_url() == expected_url + + def test_get_token_url_with_custom_authorization_server_name(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": False, + "authorization_server_name": "my_auth_server_xxxAAA777" + } + ) + provider = OktaProvider(settings) + expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/token" + assert provider.get_token_url() == expected_url + + def test_get_token_url_when_using_org_auth_server(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": True, + "authorization_server_name": None + } + ) + provider = OktaProvider(settings) + expected_url = "https://test-domain.okta.com/oauth2/v1/token" + assert provider.get_token_url() == expected_url + + def test_get_jwks_url(self): + expected_url = "https://test-domain.okta.com/oauth2/default/v1/keys" + assert self.provider.get_jwks_url() == expected_url + + def test_get_jwks_url_with_different_domain(self): + settings = Oauth2Settings( + provider="okta", + domain="dev.okta.com", + client_id="test-client", + audience="test-audience", + ) + provider = OktaProvider(settings) + expected_url = "https://dev.okta.com/oauth2/default/v1/keys" + assert provider.get_jwks_url() == expected_url + + def test_get_jwks_url_with_custom_authorization_server_name(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": False, + "authorization_server_name": "my_auth_server_xxxAAA777" + } + ) + provider = OktaProvider(settings) + expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/keys" + assert provider.get_jwks_url() == expected_url + + def test_get_jwks_url_when_using_org_auth_server(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": True, + "authorization_server_name": None + } + ) + provider = OktaProvider(settings) + expected_url = "https://test-domain.okta.com/oauth2/v1/keys" + assert provider.get_jwks_url() == expected_url + + def test_get_issuer(self): + expected_issuer = "https://test-domain.okta.com/oauth2/default" + assert self.provider.get_issuer() == expected_issuer + + def test_get_issuer_with_different_domain(self): + settings = Oauth2Settings( + provider="okta", + domain="prod.okta.com", + client_id="test-client", + audience="test-audience", + ) + provider = OktaProvider(settings) + expected_issuer = "https://prod.okta.com/oauth2/default" + assert provider.get_issuer() == expected_issuer + + def test_get_issuer_with_custom_authorization_server_name(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": False, + "authorization_server_name": "my_auth_server_xxxAAA777" + } + ) + provider = OktaProvider(settings) + expected_issuer = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777" + assert provider.get_issuer() == expected_issuer + + def test_get_issuer_when_using_org_auth_server(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": True, + "authorization_server_name": None + } + ) + provider = OktaProvider(settings) + expected_issuer = "https://test-domain.okta.com" + assert provider.get_issuer() == expected_issuer + + def test_get_audience(self): + assert self.provider.get_audience() == "test-audience" + + def test_get_audience_assertion_error_when_none(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + ) + provider = OktaProvider(settings) + + with pytest.raises(ValueError, match="Audience is required"): + provider.get_audience() + + def test_get_client_id(self): + assert self.provider.get_client_id() == "test-client-id" + + def test_get_required_fields(self): + assert set(self.provider.get_required_fields()) == set(["authorization_server_name", "using_org_auth_server"]) + + def test_oauth2_base_url(self): + assert self.provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/default" + + def test_oauth2_base_url_with_custom_authorization_server_name(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": False, + "authorization_server_name": "my_auth_server_xxxAAA777" + } + ) + + provider = OktaProvider(settings) + assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777" + + def test_oauth2_base_url_when_using_org_auth_server(self): + settings = Oauth2Settings( + provider="okta", + domain="test-domain.okta.com", + client_id="test-client-id", + audience=None, + extra={ + "using_org_auth_server": True, + "authorization_server_name": None + } + ) + provider = OktaProvider(settings) + assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2" \ No newline at end of file diff --git a/lib/cli/tests/authentication/providers/test_workos.py b/lib/cli/tests/authentication/providers/test_workos.py new file mode 100644 index 000000000..2323e8d95 --- /dev/null +++ b/lib/cli/tests/authentication/providers/test_workos.py @@ -0,0 +1,100 @@ +import pytest +from crewai_cli.authentication.main import Oauth2Settings +from crewai_cli.authentication.providers.workos import WorkosProvider + + +class TestWorkosProvider: + + @pytest.fixture(autouse=True) + def setup_method(self): + self.valid_settings = Oauth2Settings( + provider="workos", + domain="login.company.com", + client_id="test-client-id", + audience="test-audience" + ) + self.provider = WorkosProvider(self.valid_settings) + + def test_initialization_with_valid_settings(self): + provider = WorkosProvider(self.valid_settings) + assert provider.settings == self.valid_settings + assert provider.settings.provider == "workos" + assert provider.settings.domain == "login.company.com" + assert provider.settings.client_id == "test-client-id" + assert provider.settings.audience == "test-audience" + + def test_get_authorize_url(self): + expected_url = "https://login.company.com/oauth2/device_authorization" + assert self.provider.get_authorize_url() == expected_url + + def test_get_authorize_url_with_different_domain(self): + settings = Oauth2Settings( + provider="workos", + domain="login.example.com", + client_id="test-client", + audience="test-audience" + ) + provider = WorkosProvider(settings) + expected_url = "https://login.example.com/oauth2/device_authorization" + assert provider.get_authorize_url() == expected_url + + def test_get_token_url(self): + expected_url = "https://login.company.com/oauth2/token" + assert self.provider.get_token_url() == expected_url + + def test_get_token_url_with_different_domain(self): + settings = Oauth2Settings( + provider="workos", + domain="api.workos.com", + client_id="test-client", + audience="test-audience" + ) + provider = WorkosProvider(settings) + expected_url = "https://api.workos.com/oauth2/token" + assert provider.get_token_url() == expected_url + + def test_get_jwks_url(self): + expected_url = "https://login.company.com/oauth2/jwks" + assert self.provider.get_jwks_url() == expected_url + + def test_get_jwks_url_with_different_domain(self): + settings = Oauth2Settings( + provider="workos", + domain="auth.enterprise.com", + client_id="test-client", + audience="test-audience" + ) + provider = WorkosProvider(settings) + expected_url = "https://auth.enterprise.com/oauth2/jwks" + assert provider.get_jwks_url() == expected_url + + def test_get_issuer(self): + expected_issuer = "https://login.company.com" + assert self.provider.get_issuer() == expected_issuer + + def test_get_issuer_with_different_domain(self): + settings = Oauth2Settings( + provider="workos", + domain="sso.company.com", + client_id="test-client", + audience="test-audience" + ) + provider = WorkosProvider(settings) + expected_issuer = "https://sso.company.com" + assert provider.get_issuer() == expected_issuer + + def test_get_audience(self): + assert self.provider.get_audience() == "test-audience" + + def test_get_audience_fallback_to_default(self): + settings = Oauth2Settings( + provider="workos", + domain="login.company.com", + client_id="test-client-id", + audience=None + ) + provider = WorkosProvider(settings) + assert provider.get_audience() == "" + + def test_get_client_id(self): + assert self.provider.get_client_id() == "test-client-id" diff --git a/lib/cli/tests/authentication/test_auth_main.py b/lib/cli/tests/authentication/test_auth_main.py new file mode 100644 index 000000000..362ecf827 --- /dev/null +++ b/lib/cli/tests/authentication/test_auth_main.py @@ -0,0 +1,348 @@ +from datetime import datetime, timedelta +from unittest.mock import MagicMock, call, patch + +import pytest +import httpx +from crewai_cli.authentication.main import AuthenticationCommand +from crewai_cli.constants import ( + CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE, + CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID, + CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, +) + + +class TestAuthenticationCommand: + def setup_method(self): + # Mock Settings so we always use default constants regardless of local config. + with patch("crewai_cli.authentication.main.Settings") as mock_settings: + instance = mock_settings.return_value + instance.oauth2_provider = "workos" + instance.oauth2_domain = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN + instance.oauth2_client_id = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID + instance.oauth2_audience = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE + instance.oauth2_extra = {} + self.auth_command = AuthenticationCommand() + + @pytest.mark.parametrize( + "user_provider,expected_urls", + [ + ( + "workos", + { + "device_code_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/device_authorization", + "token_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/token", + "client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID, + "audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE, + "domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, + }, + ), + ], + ) + @patch("crewai_cli.authentication.main.AuthenticationCommand._get_device_code") + @patch( + "crewai_cli.authentication.main.AuthenticationCommand._display_auth_instructions" + ) + @patch("crewai_cli.authentication.main.AuthenticationCommand._poll_for_token") + @patch("crewai_cli.authentication.main.console.print") + def test_login( + self, + mock_console_print, + mock_poll, + mock_display, + mock_get_device, + user_provider, + expected_urls, + ): + mock_get_device.return_value = { + "device_code": "test_code", + "user_code": "123456", + } + + self.auth_command.login() + + mock_console_print.assert_called_once_with( + "Signing in to CrewAI AMP...\n", style="bold blue" + ) + mock_get_device.assert_called_once() + mock_display.assert_called_once_with( + {"device_code": "test_code", "user_code": "123456"} + ) + mock_poll.assert_called_once_with( + {"device_code": "test_code", "user_code": "123456"}, + ) + assert ( + self.auth_command.oauth2_provider.get_client_id() + == expected_urls["client_id"] + ) + assert ( + self.auth_command.oauth2_provider.get_audience() + == expected_urls["audience"] + ) + assert ( + self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"] + ) + + @patch("crewai_cli.authentication.main.webbrowser") + @patch("crewai_cli.authentication.main.console.print") + def test_display_auth_instructions(self, mock_console_print, mock_webbrowser): + device_code_data = { + "verification_uri_complete": "https://example.com/auth", + "user_code": "123456", + } + + self.auth_command._display_auth_instructions(device_code_data) + + expected_calls = [ + call("1. Navigate to: ", "https://example.com/auth"), + call("2. Enter the following code: ", "123456"), + ] + mock_console_print.assert_has_calls(expected_calls) + mock_webbrowser.open.assert_called_once_with("https://example.com/auth") + + @pytest.mark.parametrize( + "user_provider,jwt_config", + [ + ( + "workos", + { + "jwks_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/jwks", + "issuer": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}", + "audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE, + }, + ), + ], + ) + @pytest.mark.parametrize("has_expiration", [True, False]) + @patch("crewai_cli.authentication.main.validate_jwt_token") + @patch("crewai_cli.authentication.main.TokenManager.save_tokens") + def test_validate_and_save_token( + self, + mock_save_tokens, + mock_validate_jwt, + user_provider, + jwt_config, + has_expiration, + ): + from crewai_cli.authentication.main import Oauth2Settings + from crewai_cli.authentication.providers.workos import WorkosProvider + + if user_provider == "workos": + self.auth_command.oauth2_provider = WorkosProvider( + settings=Oauth2Settings( + provider=user_provider, + client_id="test-client-id", + domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN, + audience=jwt_config["audience"], + ) + ) + + token_data = {"access_token": "test_access_token", "id_token": "test_id_token"} + + if has_expiration: + future_timestamp = int((datetime.now() + timedelta(days=100)).timestamp()) + decoded_token = {"exp": future_timestamp} + else: + decoded_token = {} + + mock_validate_jwt.return_value = decoded_token + + self.auth_command._validate_and_save_token(token_data) + + mock_validate_jwt.assert_called_once_with( + jwt_token="test_access_token", + jwks_url=jwt_config["jwks_url"], + issuer=jwt_config["issuer"], + audience=jwt_config["audience"], + ) + + if has_expiration: + mock_save_tokens.assert_called_once_with( + "test_access_token", future_timestamp + ) + else: + mock_save_tokens.assert_called_once_with("test_access_token", 0) + + @patch("crewai_cli.tools.main.ToolCommand") + @patch("crewai_cli.authentication.main.Settings") + @patch("crewai_cli.authentication.main.console.print") + def test_login_to_tool_repository_success( + self, mock_console_print, mock_settings, mock_tool_command + ): + mock_tool_instance = MagicMock() + mock_tool_command.return_value = mock_tool_instance + + mock_settings_instance = MagicMock() + mock_settings_instance.org_name = "Test Org" + mock_settings_instance.org_uuid = "test-uuid-123" + mock_settings.return_value = mock_settings_instance + + self.auth_command._login_to_tool_repository() + + mock_tool_command.assert_called_once() + mock_tool_instance.login.assert_called_once() + + expected_calls = [ + call( + "Now logging you in to the Tool Repository... ", + style="bold blue", + end="", + ), + call("Success!\n", style="bold green"), + call( + "You are now authenticated to the tool repository for organization [bold cyan]'Test Org'[/bold cyan]", + style="green", + ), + ] + mock_console_print.assert_has_calls(expected_calls) + + @patch("crewai_cli.tools.main.ToolCommand") + @patch("crewai_cli.authentication.main.console.print") + def test_login_to_tool_repository_error( + self, mock_console_print, mock_tool_command + ): + mock_tool_instance = MagicMock() + mock_tool_instance.login.side_effect = Exception("Tool repository error") + mock_tool_command.return_value = mock_tool_instance + + self.auth_command._login_to_tool_repository() + + mock_tool_command.assert_called_once() + mock_tool_instance.login.assert_called_once() + + expected_calls = [ + call( + "Now logging you in to the Tool Repository... ", + style="bold blue", + end="", + ), + call( + "\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.", + style="yellow", + ), + call( + "Other features will work normally, but you may experience limitations with downloading and publishing tools.\nRun [bold]crewai login[/bold] to try logging in again.\n", + style="yellow", + ), + ] + mock_console_print.assert_has_calls(expected_calls) + + @patch("crewai_cli.authentication.main.httpx.post") + def test_get_device_code(self, mock_post): + mock_response = MagicMock() + mock_response.json.return_value = { + "device_code": "test_device_code", + "user_code": "123456", + "verification_uri_complete": "https://example.com/auth", + } + mock_post.return_value = mock_response + + self.auth_command.oauth2_provider = MagicMock() + self.auth_command.oauth2_provider.get_client_id.return_value = "test_client" + self.auth_command.oauth2_provider.get_authorize_url.return_value = ( + "https://example.com/device" + ) + self.auth_command.oauth2_provider.get_audience.return_value = "test_audience" + self.auth_command.oauth2_provider.get_oauth_scopes.return_value = ["openid", "profile", "email"] + + result = self.auth_command._get_device_code() + + mock_post.assert_called_once_with( + url="https://example.com/device", + data={ + "client_id": "test_client", + "scope": "openid profile email", + "audience": "test_audience", + }, + timeout=20, + ) + + assert result == { + "device_code": "test_device_code", + "user_code": "123456", + "verification_uri_complete": "https://example.com/auth", + } + + @patch("crewai_cli.authentication.main.httpx.post") + @patch("crewai_cli.authentication.main.console.print") + def test_poll_for_token_success(self, mock_console_print, mock_post): + mock_response_success = MagicMock() + mock_response_success.status_code = 200 + mock_response_success.json.return_value = { + "access_token": "test_access_token", + "id_token": "test_id_token", + } + mock_post.return_value = mock_response_success + + device_code_data = {"device_code": "test_device_code", "interval": 1} + + with ( + patch.object( + self.auth_command, "_validate_and_save_token" + ) as mock_validate, + patch.object( + self.auth_command, "_login_to_tool_repository" + ) as mock_tool_login, + ): + self.auth_command.oauth2_provider = MagicMock() + self.auth_command.oauth2_provider.get_token_url.return_value = ( + "https://example.com/token" + ) + self.auth_command.oauth2_provider.get_client_id.return_value = "test_client" + + self.auth_command._poll_for_token(device_code_data) + + mock_post.assert_called_once_with( + "https://example.com/token", + data={ + "grant_type": "urn:ietf:params:oauth:grant-type:device_code", + "device_code": "test_device_code", + "client_id": "test_client", + }, + timeout=30, + ) + + mock_validate.assert_called_once() + mock_tool_login.assert_called_once() + + expected_calls = [ + call("\nWaiting for authentication... ", style="bold blue", end=""), + call("Success!", style="bold green"), + call("\n[bold green]Welcome to CrewAI AMP![/bold green]\n"), + ] + mock_console_print.assert_has_calls(expected_calls) + + @patch("crewai_cli.authentication.main.httpx.post") + @patch("crewai_cli.authentication.main.console.print") + def test_poll_for_token_timeout(self, mock_console_print, mock_post): + mock_response_pending = MagicMock() + mock_response_pending.status_code = 400 + mock_response_pending.json.return_value = {"error": "authorization_pending"} + mock_post.return_value = mock_response_pending + + device_code_data = { + "device_code": "test_device_code", + "interval": 0.1, # Short interval for testing + } + + self.auth_command._poll_for_token(device_code_data) + + mock_console_print.assert_any_call( + "Timeout: Failed to get the token. Please try again.", style="bold red" + ) + + @patch("crewai_cli.authentication.main.httpx.post") + def test_poll_for_token_error(self, mock_post): + """Test the method to poll for token (error path).""" + # Setup mock to return error + mock_response_error = MagicMock() + mock_response_error.status_code = 400 + mock_response_error.json.return_value = { + "error": "access_denied", + "error_description": "User denied access", + } + mock_post.return_value = mock_response_error + + device_code_data = {"device_code": "test_device_code", "interval": 1} + + with pytest.raises(httpx.HTTPError): + self.auth_command._poll_for_token(device_code_data) diff --git a/lib/cli/tests/authentication/test_utils.py b/lib/cli/tests/authentication/test_utils.py new file mode 100644 index 000000000..fd8f21921 --- /dev/null +++ b/lib/cli/tests/authentication/test_utils.py @@ -0,0 +1,107 @@ +import unittest +from unittest.mock import MagicMock, patch + +import jwt + +from crewai_cli.authentication.utils import validate_jwt_token + + +@patch("crewai_cli.authentication.utils.PyJWKClient", return_value=MagicMock()) +@patch("crewai_cli.authentication.utils.jwt") +class TestUtils(unittest.TestCase): + def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient): + mock_jwt.decode.return_value = {"exp": 1719859200} + + # Create signing key object mock with a .key attribute + mock_pyjwkclient.return_value.get_signing_key_from_jwt.return_value = MagicMock( + key="mock_signing_key" + ) + + jwt_token = "aaaaa.bbbbbb.cccccc" # noqa: S105 + + decoded_token = validate_jwt_token( + jwt_token=jwt_token, + jwks_url="https://mock_jwks_url", + issuer="https://mock_issuer", + audience="app_id_xxxx", + ) + + mock_jwt.decode.assert_called_with( + jwt_token, + "mock_signing_key", + algorithms=["RS256"], + audience="app_id_xxxx", + issuer="https://mock_issuer", + leeway=10.0, + options={ + "verify_signature": True, + "verify_exp": True, + "verify_nbf": True, + "verify_iat": True, + "require": ["exp", "iat", "iss", "aud", "sub"], + }, + ) + mock_pyjwkclient.assert_called_once_with("https://mock_jwks_url") + self.assertEqual(decoded_token, {"exp": 1719859200}) + + def test_validate_jwt_token_expired(self, mock_jwt, mock_pyjwkclient): + mock_jwt.decode.side_effect = jwt.ExpiredSignatureError + with self.assertRaises(Exception): # noqa: B017 + validate_jwt_token( + jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106 + jwks_url="https://mock_jwks_url", + issuer="https://mock_issuer", + audience="app_id_xxxx", + ) + + def test_validate_jwt_token_invalid_audience(self, mock_jwt, mock_pyjwkclient): + mock_jwt.decode.side_effect = jwt.InvalidAudienceError + with self.assertRaises(Exception): # noqa: B017 + validate_jwt_token( + jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106 + jwks_url="https://mock_jwks_url", + issuer="https://mock_issuer", + audience="app_id_xxxx", + ) + + def test_validate_jwt_token_invalid_issuer(self, mock_jwt, mock_pyjwkclient): + mock_jwt.decode.side_effect = jwt.InvalidIssuerError + with self.assertRaises(Exception): # noqa: B017 + validate_jwt_token( + jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106 + jwks_url="https://mock_jwks_url", + issuer="https://mock_issuer", + audience="app_id_xxxx", + ) + + def test_validate_jwt_token_missing_required_claims( + self, mock_jwt, mock_pyjwkclient + ): + mock_jwt.decode.side_effect = jwt.MissingRequiredClaimError + with self.assertRaises(Exception): # noqa: B017 + validate_jwt_token( + jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106 + jwks_url="https://mock_jwks_url", + issuer="https://mock_issuer", + audience="app_id_xxxx", + ) + + def test_validate_jwt_token_jwks_error(self, mock_jwt, mock_pyjwkclient): + mock_jwt.decode.side_effect = jwt.exceptions.PyJWKClientError + with self.assertRaises(Exception): # noqa: B017 + validate_jwt_token( + jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106 + jwks_url="https://mock_jwks_url", + issuer="https://mock_issuer", + audience="app_id_xxxx", + ) + + def test_validate_jwt_token_invalid_token(self, mock_jwt, mock_pyjwkclient): + mock_jwt.decode.side_effect = jwt.InvalidTokenError + with self.assertRaises(Exception): # noqa: B017 + validate_jwt_token( + jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106 + jwks_url="https://mock_jwks_url", + issuer="https://mock_issuer", + audience="app_id_xxxx", + ) diff --git a/lib/cli/tests/deploy/__init__.py b/lib/cli/tests/deploy/__init__.py new file mode 100644 index 000000000..3b2cf1906 --- /dev/null +++ b/lib/cli/tests/deploy/__init__.py @@ -0,0 +1 @@ +"""Tests for CLI deploy.""" diff --git a/lib/cli/tests/deploy/test_deploy_main.py b/lib/cli/tests/deploy/test_deploy_main.py new file mode 100644 index 000000000..96f769413 --- /dev/null +++ b/lib/cli/tests/deploy/test_deploy_main.py @@ -0,0 +1,266 @@ +import sys +import unittest +from io import StringIO +from unittest.mock import MagicMock, Mock, patch + +import pytest +import json + +import httpx +from crewai_cli.deploy.main import DeployCommand +from crewai_cli.utils import parse_toml + + +class TestDeployCommand(unittest.TestCase): + @patch("crewai_cli.command.get_auth_token") + @patch("crewai_cli.deploy.main.get_project_name") + @patch("crewai_cli.command.PlusAPI") + def setUp(self, mock_plus_api, mock_get_project_name, mock_get_auth_token): + self.mock_get_auth_token = mock_get_auth_token + self.mock_get_project_name = mock_get_project_name + self.mock_plus_api = mock_plus_api + + self.mock_get_auth_token.return_value = "test_token" + self.mock_get_project_name.return_value = "test_project" + + self.deploy_command = DeployCommand() + self.mock_client = self.deploy_command.plus_api_client + + def test_init_success(self): + self.assertEqual(self.deploy_command.project_name, "test_project") + self.mock_plus_api.assert_called_once_with(api_key="test_token") + + @patch("crewai_cli.command.get_auth_token") + def test_init_failure(self, mock_get_auth_token): + mock_get_auth_token.side_effect = Exception("Auth failed") + + with self.assertRaises(SystemExit): + DeployCommand() + + def test_validate_response_successful_response(self): + mock_response = Mock(spec=httpx.Response) + mock_response.json.return_value = {"message": "Success"} + mock_response.status_code = 200 + mock_response.is_success = True + + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command._validate_response(mock_response) + assert fake_out.getvalue() == "" + + def test_validate_response_json_decode_error(self): + mock_response = Mock(spec=httpx.Response) + mock_response.json.side_effect = json.JSONDecodeError("Decode error", "", 0) + mock_response.status_code = 500 + mock_response.content = b"Invalid JSON" + + with patch("sys.stdout", new=StringIO()) as fake_out: + with pytest.raises(SystemExit): + self.deploy_command._validate_response(mock_response) + output = fake_out.getvalue() + assert ( + "Failed to parse response from Enterprise API failed. Details:" + in output + ) + assert "Status Code: 500" in output + assert "Response:\nInvalid JSON" in output + + def test_validate_response_422_error(self): + mock_response = Mock(spec=httpx.Response) + mock_response.json.return_value = { + "field1": ["Error message 1"], + "field2": ["Error message 2"], + } + mock_response.status_code = 422 + mock_response.is_success = False + + with patch("sys.stdout", new=StringIO()) as fake_out: + with pytest.raises(SystemExit): + self.deploy_command._validate_response(mock_response) + output = fake_out.getvalue() + assert ( + "Failed to complete operation. Please fix the following errors:" + in output + ) + assert "Field1 Error message 1" in output + assert "Field2 Error message 2" in output + + def test_validate_response_other_error(self): + mock_response = Mock(spec=httpx.Response) + mock_response.json.return_value = {"error": "Something went wrong"} + mock_response.status_code = 500 + mock_response.is_success = False + + with patch("sys.stdout", new=StringIO()) as fake_out: + with pytest.raises(SystemExit): + self.deploy_command._validate_response(mock_response) + output = fake_out.getvalue() + assert "Request to Enterprise API failed. Details:" in output + assert "Details:\nSomething went wrong" in output + + def test_standard_no_param_error_message(self): + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command._standard_no_param_error_message() + self.assertIn("No UUID provided", fake_out.getvalue()) + + def test_display_deployment_info(self): + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command._display_deployment_info( + {"uuid": "test-uuid", "status": "deployed"} + ) + self.assertIn("Deploying the crew...", fake_out.getvalue()) + self.assertIn("test-uuid", fake_out.getvalue()) + self.assertIn("deployed", fake_out.getvalue()) + + def test_display_logs(self): + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command._display_logs( + [{"timestamp": "2023-01-01", "level": "INFO", "message": "Test log"}] + ) + self.assertIn("2023-01-01 - INFO: Test log", fake_out.getvalue()) + + @patch("crewai_cli.deploy.main.DeployCommand._display_deployment_info") + def test_deploy_with_uuid(self, mock_display): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"uuid": "test-uuid"} + self.mock_client.deploy_by_uuid.return_value = mock_response + + self.deploy_command.deploy(uuid="test-uuid") + + self.mock_client.deploy_by_uuid.assert_called_once_with("test-uuid") + mock_display.assert_called_once_with({"uuid": "test-uuid"}) + + @patch("crewai_cli.deploy.main.DeployCommand._display_deployment_info") + def test_deploy_with_project_name(self, mock_display): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"uuid": "test-uuid"} + self.mock_client.deploy_by_name.return_value = mock_response + + self.deploy_command.deploy() + + self.mock_client.deploy_by_name.assert_called_once_with("test_project") + mock_display.assert_called_once_with({"uuid": "test-uuid"}) + + @patch("crewai_cli.deploy.main.fetch_and_json_env_file") + @patch("crewai_cli.deploy.main.git.Repository.origin_url") + @patch("builtins.input") + def test_create_crew(self, mock_input, mock_git_origin_url, mock_fetch_env): + mock_fetch_env.return_value = {"ENV_VAR": "value"} + mock_git_origin_url.return_value = "https://github.com/test/repo.git" + mock_input.return_value = "" + + mock_response = MagicMock() + mock_response.status_code = 201 + mock_response.json.return_value = {"uuid": "new-uuid", "status": "created"} + self.mock_client.create_crew.return_value = mock_response + + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command.create_crew() + self.assertIn("Deployment created successfully!", fake_out.getvalue()) + self.assertIn("new-uuid", fake_out.getvalue()) + + def test_list_crews(self): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [ + {"name": "Crew1", "uuid": "uuid1", "status": "active"}, + {"name": "Crew2", "uuid": "uuid2", "status": "inactive"}, + ] + self.mock_client.list_crews.return_value = mock_response + + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command.list_crews() + self.assertIn("Crew1 (uuid1) active", fake_out.getvalue()) + self.assertIn("Crew2 (uuid2) inactive", fake_out.getvalue()) + + def test_get_crew_status(self): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = {"name": "InternalCrew", "status": "active"} + self.mock_client.crew_status_by_name.return_value = mock_response + + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command.get_crew_status() + self.assertIn("InternalCrew", fake_out.getvalue()) + self.assertIn("active", fake_out.getvalue()) + + def test_get_crew_logs(self): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = [ + {"timestamp": "2023-01-01", "level": "INFO", "message": "Log1"}, + {"timestamp": "2023-01-02", "level": "ERROR", "message": "Log2"}, + ] + self.mock_client.crew_by_name.return_value = mock_response + + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command.get_crew_logs(None) + self.assertIn("2023-01-01 - INFO: Log1", fake_out.getvalue()) + self.assertIn("2023-01-02 - ERROR: Log2", fake_out.getvalue()) + + def test_remove_crew(self): + mock_response = MagicMock() + mock_response.status_code = 204 + self.mock_client.delete_crew_by_name.return_value = mock_response + + with patch("sys.stdout", new=StringIO()) as fake_out: + self.deploy_command.remove_crew(None) + self.assertIn( + "Crew 'test_project' removed successfully", fake_out.getvalue() + ) + + @unittest.skipIf(sys.version_info < (3, 11), "Requires Python 3.11+") + def test_parse_toml_python_311_plus(self): + toml_content = """ + [tool.poetry] + name = "test_project" + version = "0.1.0" + + [tool.poetry.dependencies] + python = "^3.11" + crewai = { extras = ["tools"], version = ">=0.51.0,<1.0.0" } + """ + parsed = parse_toml(toml_content) + self.assertEqual(parsed["tool"]["poetry"]["name"], "test_project") + + @patch( + "builtins.open", + new_callable=unittest.mock.mock_open, + read_data=""" + [project] + name = "test_project" + version = "0.1.0" + requires-python = ">=3.10,<3.14" + dependencies = ["crewai"] + """, + ) + def test_get_project_name_python_310(self, mock_open): + from crewai_cli.utils import get_project_name + + project_name = get_project_name() + print("project_name", project_name) + self.assertEqual(project_name, "test_project") + + @unittest.skipIf(sys.version_info < (3, 11), "Requires Python 3.11+") + @patch( + "builtins.open", + new_callable=unittest.mock.mock_open, + read_data=""" + [project] + name = "test_project" + version = "0.1.0" + requires-python = ">=3.10,<3.14" + dependencies = ["crewai"] + """, + ) + def test_get_project_name_python_311_plus(self, mock_open): + from crewai_cli.utils import get_project_name + + project_name = get_project_name() + self.assertEqual(project_name, "test_project") + + def test_get_crewai_version(self): + from crewai_cli.version import get_crewai_version + + assert isinstance(get_crewai_version(), str) diff --git a/lib/cli/tests/enterprise/__init__.py b/lib/cli/tests/enterprise/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lib/cli/tests/enterprise/test_main.py b/lib/cli/tests/enterprise/test_main.py new file mode 100644 index 000000000..988c55ab4 --- /dev/null +++ b/lib/cli/tests/enterprise/test_main.py @@ -0,0 +1,158 @@ +import tempfile +import unittest +from pathlib import Path +from unittest.mock import Mock, patch + +import json + +import httpx + +from crewai_cli.enterprise.main import EnterpriseConfigureCommand +from crewai_cli.settings.main import SettingsCommand +import shutil + + +class TestEnterpriseConfigureCommand(unittest.TestCase): + def setUp(self): + self.test_dir = Path(tempfile.mkdtemp()) + self.config_path = self.test_dir / "settings.json" + + with patch('crewai_cli.enterprise.main.SettingsCommand') as mock_settings_command_class: + self.mock_settings_command = Mock(spec=SettingsCommand) + mock_settings_command_class.return_value = self.mock_settings_command + + self.enterprise_command = EnterpriseConfigureCommand() + + def tearDown(self): + shutil.rmtree(self.test_dir) + + @patch('crewai_cli.enterprise.main.httpx.get') + @patch('crewai_cli.enterprise.main.get_crewai_version') + def test_successful_configuration(self, mock_get_version, mock_requests_get): + mock_get_version.return_value = "1.0.0" + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = { + 'audience': 'test_audience', + 'domain': 'test.domain.com', + 'device_authorization_client_id': 'test_client_id', + 'provider': 'workos', + 'extra': {} + } + mock_requests_get.return_value = mock_response + + enterprise_url = "https://enterprise.example.com" + self.enterprise_command.configure(enterprise_url) + + expected_headers = { + "Content-Type": "application/json", + "User-Agent": "CrewAI-CLI/1.0.0", + "X-Crewai-Version": "1.0.0", + } + mock_requests_get.assert_called_once_with( + "https://enterprise.example.com/auth/parameters", + timeout=30, + headers=expected_headers + ) + + expected_calls = [ + ('enterprise_base_url', 'https://enterprise.example.com'), + ('oauth2_provider', 'workos'), + ('oauth2_audience', 'test_audience'), + ('oauth2_client_id', 'test_client_id'), + ('oauth2_domain', 'test.domain.com'), + ('oauth2_extra', {}) + ] + + actual_calls = self.mock_settings_command.set.call_args_list + self.assertEqual(len(actual_calls), 6) + + for i, (key, value) in enumerate(expected_calls): + call_args = actual_calls[i][0] + self.assertEqual(call_args[0], key) + self.assertEqual(call_args[1], value) + + @patch('crewai_cli.enterprise.main.httpx.get') + @patch('crewai_cli.enterprise.main.get_crewai_version') + def test_http_error_handling(self, mock_get_version, mock_requests_get): + mock_get_version.return_value = "1.0.0" + + mock_response = Mock() + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + "404 Not Found", + request=httpx.Request("GET", "http://test"), + response=httpx.Response(404), + ) + mock_requests_get.return_value = mock_response + + with self.assertRaises(SystemExit): + self.enterprise_command.configure("https://enterprise.example.com") + + @patch('crewai_cli.enterprise.main.httpx.get') + @patch('crewai_cli.enterprise.main.get_crewai_version') + def test_invalid_json_response(self, mock_get_version, mock_requests_get): + mock_get_version.return_value = "1.0.0" + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0) + mock_requests_get.return_value = mock_response + + with self.assertRaises(SystemExit): + self.enterprise_command.configure("https://enterprise.example.com") + + @patch('crewai_cli.enterprise.main.httpx.get') + @patch('crewai_cli.enterprise.main.get_crewai_version') + def test_missing_required_fields(self, mock_get_version, mock_requests_get): + mock_get_version.return_value = "1.0.0" + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = { + 'audience': 'test_audience', + } + mock_requests_get.return_value = mock_response + + with self.assertRaises(SystemExit): + self.enterprise_command.configure("https://enterprise.example.com") + + @patch('crewai_cli.enterprise.main.httpx.get') + @patch('crewai_cli.enterprise.main.get_crewai_version') + def test_settings_update_error(self, mock_get_version, mock_requests_get): + mock_get_version.return_value = "1.0.0" + + mock_response = Mock() + mock_response.status_code = 200 + mock_response.raise_for_status.return_value = None + mock_response.json.return_value = { + 'audience': 'test_audience', + 'domain': 'test.domain.com', + 'device_authorization_client_id': 'test_client_id', + 'provider': 'workos' + } + mock_requests_get.return_value = mock_response + + self.mock_settings_command.set.side_effect = Exception("Settings update failed") + + with self.assertRaises(SystemExit): + self.enterprise_command.configure("https://enterprise.example.com") + + def test_url_trailing_slash_removal(self): + with patch.object(self.enterprise_command, '_fetch_oauth_config') as mock_fetch, \ + patch.object(self.enterprise_command, '_update_oauth_settings') as mock_update: + + mock_fetch.return_value = { + 'audience': 'test_audience', + 'domain': 'test.domain.com', + 'device_authorization_client_id': 'test_client_id', + 'provider': 'workos' + } + + self.enterprise_command.configure("https://enterprise.example.com/") + + mock_fetch.assert_called_once_with("https://enterprise.example.com") + mock_update.assert_called_once() diff --git a/lib/cli/tests/organization/__init__.py b/lib/cli/tests/organization/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/lib/cli/tests/organization/__init__.py @@ -0,0 +1 @@ + diff --git a/lib/cli/tests/organization/test_main.py b/lib/cli/tests/organization/test_main.py new file mode 100644 index 000000000..36eb99d9f --- /dev/null +++ b/lib/cli/tests/organization/test_main.py @@ -0,0 +1,239 @@ +import unittest +from unittest.mock import MagicMock, patch, call + +import pytest +from click.testing import CliRunner +import httpx + +from crewai_cli.organization.main import OrganizationCommand +from crewai_cli.cli import org_list, switch, current + + +@pytest.fixture +def runner(): + return CliRunner() + + +@pytest.fixture +def org_command(): + with patch.object(OrganizationCommand, "__init__", return_value=None): + command = OrganizationCommand() + yield command + + +@pytest.fixture +def mock_settings(): + with patch("crewai_cli.organization.main.Settings") as mock_settings_class: + mock_settings_instance = MagicMock() + mock_settings_class.return_value = mock_settings_instance + yield mock_settings_instance + + +@patch("crewai_cli.cli.OrganizationCommand") +def test_org_list_command(mock_org_command_class, runner): + mock_org_instance = MagicMock() + mock_org_command_class.return_value = mock_org_instance + + result = runner.invoke(org_list) + + assert result.exit_code == 0 + mock_org_command_class.assert_called_once() + mock_org_instance.list.assert_called_once() + + +@patch("crewai_cli.cli.OrganizationCommand") +def test_org_switch_command(mock_org_command_class, runner): + mock_org_instance = MagicMock() + mock_org_command_class.return_value = mock_org_instance + + result = runner.invoke(switch, ["test-id"]) + + assert result.exit_code == 0 + mock_org_command_class.assert_called_once() + mock_org_instance.switch.assert_called_once_with("test-id") + + +@patch("crewai_cli.cli.OrganizationCommand") +def test_org_current_command(mock_org_command_class, runner): + mock_org_instance = MagicMock() + mock_org_command_class.return_value = mock_org_instance + + result = runner.invoke(current) + + assert result.exit_code == 0 + mock_org_command_class.assert_called_once() + mock_org_instance.current.assert_called_once() + + +class TestOrganizationCommand(unittest.TestCase): + def setUp(self): + with patch.object(OrganizationCommand, "__init__", return_value=None): + self.org_command = OrganizationCommand() + self.org_command.plus_api_client = MagicMock() + + @patch("crewai_cli.organization.main.console") + @patch("crewai_cli.organization.main.Table") + def test_list_organizations_success(self, mock_table, mock_console): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = [ + {"name": "Org 1", "uuid": "org-123"}, + {"name": "Org 2", "uuid": "org-456"}, + ] + self.org_command.plus_api_client = MagicMock() + self.org_command.plus_api_client.get_organizations.return_value = mock_response + + mock_console.print = MagicMock() + + self.org_command.list() + + self.org_command.plus_api_client.get_organizations.assert_called_once() + mock_table.assert_called_once_with(title="Your Organizations") + mock_table.return_value.add_column.assert_has_calls( + [call("Name", style="cyan"), call("ID", style="green")] + ) + mock_table.return_value.add_row.assert_has_calls( + [call("Org 1", "org-123"), call("Org 2", "org-456")] + ) + + @patch("crewai_cli.organization.main.console") + def test_list_organizations_empty(self, mock_console): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = [] + self.org_command.plus_api_client = MagicMock() + self.org_command.plus_api_client.get_organizations.return_value = mock_response + + self.org_command.list() + + self.org_command.plus_api_client.get_organizations.assert_called_once() + mock_console.print.assert_called_once_with( + "You don't belong to any organizations yet.", style="yellow" + ) + + @patch("crewai_cli.organization.main.console") + def test_list_organizations_api_error(self, mock_console): + self.org_command.plus_api_client = MagicMock() + self.org_command.plus_api_client.get_organizations.side_effect = ( + httpx.HTTPError("API Error") + ) + + with pytest.raises(SystemExit): + self.org_command.list() + + self.org_command.plus_api_client.get_organizations.assert_called_once() + mock_console.print.assert_called_once_with( + "Failed to retrieve organization list: API Error", style="bold red" + ) + + @patch("crewai_cli.organization.main.console") + @patch("crewai_cli.organization.main.Settings") + def test_switch_organization_success(self, mock_settings_class, mock_console): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = [ + {"name": "Org 1", "uuid": "org-123"}, + {"name": "Test Org", "uuid": "test-id"}, + ] + self.org_command.plus_api_client = MagicMock() + self.org_command.plus_api_client.get_organizations.return_value = mock_response + + mock_settings_instance = MagicMock() + mock_settings_class.return_value = mock_settings_instance + + self.org_command.switch("test-id") + + self.org_command.plus_api_client.get_organizations.assert_called_once() + mock_settings_instance.dump.assert_called_once() + assert mock_settings_instance.org_name == "Test Org" + assert mock_settings_instance.org_uuid == "test-id" + mock_console.print.assert_called_once_with( + "Successfully switched to Test Org (test-id)", style="bold green" + ) + + @patch("crewai_cli.organization.main.console") + def test_switch_organization_not_found(self, mock_console): + mock_response = MagicMock() + mock_response.raise_for_status = MagicMock() + mock_response.json.return_value = [ + {"name": "Org 1", "uuid": "org-123"}, + {"name": "Org 2", "uuid": "org-456"}, + ] + self.org_command.plus_api_client = MagicMock() + self.org_command.plus_api_client.get_organizations.return_value = mock_response + + self.org_command.switch("non-existent-id") + + self.org_command.plus_api_client.get_organizations.assert_called_once() + mock_console.print.assert_called_once_with( + "Organization with id 'non-existent-id' not found.", style="bold red" + ) + + @patch("crewai_cli.organization.main.console") + @patch("crewai_cli.organization.main.Settings") + def test_current_organization_with_org(self, mock_settings_class, mock_console): + mock_settings_instance = MagicMock() + mock_settings_instance.org_name = "Test Org" + mock_settings_instance.org_uuid = "test-id" + mock_settings_class.return_value = mock_settings_instance + + self.org_command.current() + + self.org_command.plus_api_client.get_organizations.assert_not_called() + mock_console.print.assert_called_once_with( + "Currently logged in to organization Test Org (test-id)", style="bold green" + ) + + @patch("crewai_cli.organization.main.console") + @patch("crewai_cli.organization.main.Settings") + def test_current_organization_without_org(self, mock_settings_class, mock_console): + mock_settings_instance = MagicMock() + mock_settings_instance.org_uuid = None + mock_settings_class.return_value = mock_settings_instance + + self.org_command.current() + + assert mock_console.print.call_count == 3 + mock_console.print.assert_any_call( + "You're not currently logged in to any organization.", style="yellow" + ) + + @patch("crewai_cli.organization.main.console") + def test_list_organizations_unauthorized(self, mock_console): + mock_response = MagicMock() + mock_http_error = httpx.HTTPStatusError( + "401 Client Error: Unauthorized", + request=httpx.Request("GET", "http://test"), + response=httpx.Response(401), + ) + + mock_response.raise_for_status.side_effect = mock_http_error + self.org_command.plus_api_client.get_organizations.return_value = mock_response + + self.org_command.list() + + self.org_command.plus_api_client.get_organizations.assert_called_once() + mock_console.print.assert_called_once_with( + "You are not logged in to any organization. Use 'crewai login' to login.", + style="bold red", + ) + + @patch("crewai_cli.organization.main.console") + def test_switch_organization_unauthorized(self, mock_console): + mock_response = MagicMock() + mock_http_error = httpx.HTTPStatusError( + "401 Client Error: Unauthorized", + request=httpx.Request("GET", "http://test"), + response=httpx.Response(401), + ) + + mock_response.raise_for_status.side_effect = mock_http_error + self.org_command.plus_api_client.get_organizations.return_value = mock_response + + self.org_command.switch("test-id") + + self.org_command.plus_api_client.get_organizations.assert_called_once() + mock_console.print.assert_called_once_with( + "You are not logged in to any organization. Use 'crewai login' to login.", + style="bold red", + ) diff --git a/lib/cli/tests/test_cli.py b/lib/cli/tests/test_cli.py new file mode 100644 index 000000000..ab6eef849 --- /dev/null +++ b/lib/cli/tests/test_cli.py @@ -0,0 +1,255 @@ +from pathlib import Path +from unittest import mock + +import pytest +from click.testing import CliRunner +from crewai_cli.cli import ( + deploy_create, + deploy_list, + deploy_logs, + deploy_push, + deploy_remove, + deply_status, + flow_add_crew, + login, + reset_memories, + test, + train, + version, +) + + +@pytest.fixture +def runner(): + return CliRunner() + + +@mock.patch("crewai_cli.cli.train_crew") +def test_train_default_iterations(train_crew, runner): + result = runner.invoke(train) + + train_crew.assert_called_once_with(5, "trained_agents_data.pkl") + assert result.exit_code == 0 + assert "Training the Crew for 5 iterations" in result.output + + +@mock.patch("crewai_cli.cli.train_crew") +def test_train_custom_iterations(train_crew, runner): + result = runner.invoke(train, ["--n_iterations", "10"]) + + train_crew.assert_called_once_with(10, "trained_agents_data.pkl") + assert result.exit_code == 0 + assert "Training the Crew for 10 iterations" in result.output + + +@mock.patch("crewai_cli.cli.train_crew") +def test_train_invalid_string_iterations(train_crew, runner): + result = runner.invoke(train, ["--n_iterations", "invalid"]) + + train_crew.assert_not_called() + assert result.exit_code == 2 + assert ( + "Usage: train [OPTIONS]\nTry 'train --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n" + in result.output + ) + + +def test_reset_no_memory_flags(runner): + result = runner.invoke( + reset_memories, + ) + assert ( + result.output + == "Please specify at least one memory type to reset using the appropriate flags.\n" + ) + + +def test_version_flag(runner): + result = runner.invoke(version) + + assert result.exit_code == 0 + assert "crewai version:" in result.output + + +def test_version_command(runner): + result = runner.invoke(version) + + assert result.exit_code == 0 + assert "crewai version:" in result.output + + +def test_version_command_with_tools(runner): + result = runner.invoke(version, ["--tools"]) + + assert result.exit_code == 0 + assert "crewai version:" in result.output + assert ( + "crewai tools version:" in result.output + or "crewai tools not installed" in result.output + ) + + +@mock.patch("crewai_cli.cli.evaluate_crew") +def test_test_default_iterations(evaluate_crew, runner): + result = runner.invoke(test) + + evaluate_crew.assert_called_once_with(3, "gpt-4o-mini") + assert result.exit_code == 0 + assert "Testing the crew for 3 iterations with model gpt-4o-mini" in result.output + + +@mock.patch("crewai_cli.cli.evaluate_crew") +def test_test_custom_iterations(evaluate_crew, runner): + result = runner.invoke(test, ["--n_iterations", "5", "--model", "gpt-4o"]) + + evaluate_crew.assert_called_once_with(5, "gpt-4o") + assert result.exit_code == 0 + assert "Testing the crew for 5 iterations with model gpt-4o" in result.output + + +@mock.patch("crewai_cli.cli.evaluate_crew") +def test_test_invalid_string_iterations(evaluate_crew, runner): + result = runner.invoke(test, ["--n_iterations", "invalid"]) + + evaluate_crew.assert_not_called() + assert result.exit_code == 2 + assert ( + "Usage: test [OPTIONS]\nTry 'test --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n" + in result.output + ) + + +@mock.patch("crewai_cli.cli.AuthenticationCommand") +def test_login(command, runner): + mock_auth = command.return_value + result = runner.invoke(login) + + assert result.exit_code == 0 + mock_auth.login.assert_called_once() + + +@mock.patch("crewai_cli.cli.DeployCommand") +def test_deploy_create(command, runner): + mock_deploy = command.return_value + result = runner.invoke(deploy_create) + + assert result.exit_code == 0 + mock_deploy.create_crew.assert_called_once() + + +@mock.patch("crewai_cli.cli.DeployCommand") +def test_deploy_list(command, runner): + mock_deploy = command.return_value + result = runner.invoke(deploy_list) + + assert result.exit_code == 0 + mock_deploy.list_crews.assert_called_once() + + +@mock.patch("crewai_cli.cli.DeployCommand") +def test_deploy_push(command, runner): + mock_deploy = command.return_value + uuid = "test-uuid" + result = runner.invoke(deploy_push, ["-u", uuid]) + + assert result.exit_code == 0 + mock_deploy.deploy.assert_called_once_with(uuid=uuid) + + +@mock.patch("crewai_cli.cli.DeployCommand") +def test_deploy_push_no_uuid(command, runner): + mock_deploy = command.return_value + result = runner.invoke(deploy_push) + + assert result.exit_code == 0 + mock_deploy.deploy.assert_called_once_with(uuid=None) + + +@mock.patch("crewai_cli.cli.DeployCommand") +def test_deploy_status(command, runner): + mock_deploy = command.return_value + uuid = "test-uuid" + result = runner.invoke(deply_status, ["-u", uuid]) + + assert result.exit_code == 0 + mock_deploy.get_crew_status.assert_called_once_with(uuid=uuid) + + +@mock.patch("crewai_cli.cli.DeployCommand") +def test_deploy_status_no_uuid(command, runner): + mock_deploy = command.return_value + result = runner.invoke(deply_status) + + assert result.exit_code == 0 + mock_deploy.get_crew_status.assert_called_once_with(uuid=None) + + +@mock.patch("crewai_cli.cli.DeployCommand") +def test_deploy_logs(command, runner): + mock_deploy = command.return_value + uuid = "test-uuid" + result = runner.invoke(deploy_logs, ["-u", uuid]) + + assert result.exit_code == 0 + mock_deploy.get_crew_logs.assert_called_once_with(uuid=uuid) + + +@mock.patch("crewai_cli.cli.DeployCommand") +def test_deploy_logs_no_uuid(command, runner): + mock_deploy = command.return_value + result = runner.invoke(deploy_logs) + + assert result.exit_code == 0 + mock_deploy.get_crew_logs.assert_called_once_with(uuid=None) + + +@mock.patch("crewai_cli.cli.DeployCommand") +def test_deploy_remove(command, runner): + mock_deploy = command.return_value + uuid = "test-uuid" + result = runner.invoke(deploy_remove, ["-u", uuid]) + + assert result.exit_code == 0 + mock_deploy.remove_crew.assert_called_once_with(uuid=uuid) + + +@mock.patch("crewai_cli.cli.DeployCommand") +def test_deploy_remove_no_uuid(command, runner): + mock_deploy = command.return_value + result = runner.invoke(deploy_remove) + + assert result.exit_code == 0 + mock_deploy.remove_crew.assert_called_once_with(uuid=None) + + +@mock.patch("crewai_cli.add_crew_to_flow.create_embedded_crew") +@mock.patch("pathlib.Path.exists", return_value=True) +def test_flow_add_crew(mock_path_exists, mock_create_embedded_crew, runner): + crew_name = "new_crew" + result = runner.invoke(flow_add_crew, [crew_name]) + + assert result.exit_code == 0, f"Command failed with output: {result.output}" + assert f"Adding crew {crew_name} to the flow" in result.output + + mock_create_embedded_crew.assert_called_once() + call_args, call_kwargs = mock_create_embedded_crew.call_args + assert call_args[0] == crew_name + assert "parent_folder" in call_kwargs + assert isinstance(call_kwargs["parent_folder"], Path) + + +def test_add_crew_to_flow_not_in_root(runner): + with mock.patch("pathlib.Path.exists", autospec=True) as mock_exists: + def exists_side_effect(self): + if self.name == "pyproject.toml": + return False + return True + + mock_exists.side_effect = exists_side_effect + + result = runner.invoke(flow_add_crew, ["new_crew"]) + + assert result.exit_code != 0 + assert "This command must be run from the root of a flow project." in str( + result.output + ) diff --git a/lib/cli/tests/test_config.py b/lib/cli/tests/test_config.py new file mode 100644 index 000000000..46b4a6c81 --- /dev/null +++ b/lib/cli/tests/test_config.py @@ -0,0 +1,148 @@ +import json +import shutil +import tempfile +import unittest +from datetime import datetime, timedelta +from pathlib import Path +from unittest.mock import MagicMock, patch + +from crewai_cli.config import ( + CLI_SETTINGS_KEYS, + DEFAULT_CLI_SETTINGS, + USER_SETTINGS_KEYS, + Settings, +) +from crewai_cli.shared.token_manager import TokenManager + + +class TestSettings(unittest.TestCase): + def setUp(self): + self.test_dir = Path(tempfile.mkdtemp()) + self.config_path = self.test_dir / "settings.json" + + def tearDown(self): + shutil.rmtree(self.test_dir) + + def test_empty_initialization(self): + settings = Settings(config_path=self.config_path) + self.assertIsNone(settings.tool_repository_username) + self.assertIsNone(settings.tool_repository_password) + + def test_initialization_with_data(self): + settings = Settings( + config_path=self.config_path, tool_repository_username="user1" + ) + self.assertEqual(settings.tool_repository_username, "user1") + self.assertIsNone(settings.tool_repository_password) + + def test_initialization_with_existing_file(self): + self.config_path.parent.mkdir(parents=True, exist_ok=True) + with self.config_path.open("w") as f: + json.dump({"tool_repository_username": "file_user"}, f) + + settings = Settings(config_path=self.config_path) + self.assertEqual(settings.tool_repository_username, "file_user") + + def test_merge_file_and_input_data(self): + self.config_path.parent.mkdir(parents=True, exist_ok=True) + with self.config_path.open("w") as f: + json.dump( + { + "tool_repository_username": "file_user", + "tool_repository_password": "file_pass", + }, + f, + ) + + settings = Settings( + config_path=self.config_path, tool_repository_username="new_user" + ) + self.assertEqual(settings.tool_repository_username, "new_user") + self.assertEqual(settings.tool_repository_password, "file_pass") + + def test_clear_user_settings(self): + user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS} + + settings = Settings(config_path=self.config_path, **user_settings) + settings.clear_user_settings() + + for key in user_settings.keys(): + self.assertEqual(getattr(settings, key), None) + + @patch("crewai_cli.config.TokenManager") + def test_reset_settings(self, mock_token_manager): + user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS} + cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS if key != "oauth2_extra"} + cli_settings["oauth2_extra"] = {"scope": "xxx", "other": "yyy"} + + settings = Settings( + config_path=self.config_path, **user_settings, **cli_settings + ) + + mock_token_manager.return_value = MagicMock() + TokenManager().save_tokens( + "aaa.bbb.ccc", (datetime.now() + timedelta(seconds=36000)).timestamp() + ) + + settings.reset() + + for key in user_settings.keys(): + self.assertEqual(getattr(settings, key), None) + for key in cli_settings.keys(): + self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS.get(key)) + + mock_token_manager.return_value.clear_tokens.assert_called_once() + + def test_dump_new_settings(self): + settings = Settings( + config_path=self.config_path, tool_repository_username="user1" + ) + settings.dump() + + with self.config_path.open("r") as f: + saved_data = json.load(f) + + self.assertEqual(saved_data["tool_repository_username"], "user1") + + def test_update_existing_settings(self): + self.config_path.parent.mkdir(parents=True, exist_ok=True) + with self.config_path.open("w") as f: + json.dump({"existing_setting": "value"}, f) + + settings = Settings( + config_path=self.config_path, tool_repository_username="user1" + ) + settings.dump() + + with self.config_path.open("r") as f: + saved_data = json.load(f) + + self.assertEqual(saved_data["existing_setting"], "value") + self.assertEqual(saved_data["tool_repository_username"], "user1") + + def test_none_values(self): + settings = Settings(config_path=self.config_path, tool_repository_username=None) + settings.dump() + + with self.config_path.open("r") as f: + saved_data = json.load(f) + + self.assertIsNone(saved_data.get("tool_repository_username")) + + def test_invalid_json_in_config(self): + self.config_path.parent.mkdir(parents=True, exist_ok=True) + with self.config_path.open("w") as f: + f.write("invalid json") + + try: + settings = Settings(config_path=self.config_path) + self.assertIsNone(settings.tool_repository_username) + except json.JSONDecodeError: + self.fail("Settings initialization should handle invalid JSON") + + def test_empty_config_file(self): + self.config_path.parent.mkdir(parents=True, exist_ok=True) + self.config_path.touch() + + settings = Settings(config_path=self.config_path) + self.assertIsNone(settings.tool_repository_username) diff --git a/lib/cli/tests/test_constants.py b/lib/cli/tests/test_constants.py new file mode 100644 index 000000000..527ae1dec --- /dev/null +++ b/lib/cli/tests/test_constants.py @@ -0,0 +1,20 @@ +from crewai_cli.constants import ENV_VARS, MODELS, PROVIDERS + + +def test_huggingface_in_providers(): + """Test that Huggingface is in the PROVIDERS list.""" + assert "huggingface" in PROVIDERS + + +def test_huggingface_env_vars(): + """Test that Huggingface environment variables are properly configured.""" + assert "huggingface" in ENV_VARS + assert any( + detail.get("key_name") == "HF_TOKEN" for detail in ENV_VARS["huggingface"] + ) + + +def test_huggingface_models(): + """Test that Huggingface models are properly configured.""" + assert "huggingface" in MODELS + assert len(MODELS["huggingface"]) > 0 diff --git a/lib/cli/tests/test_create_crew.py b/lib/cli/tests/test_create_crew.py new file mode 100644 index 000000000..83fdbbeeb --- /dev/null +++ b/lib/cli/tests/test_create_crew.py @@ -0,0 +1,347 @@ +import keyword +import shutil +import tempfile +from pathlib import Path +from unittest import mock + +import pytest +from click.testing import CliRunner +from crewai_cli.create_crew import create_crew, create_folder_structure + + +@pytest.fixture +def runner(): + return CliRunner() + + +@pytest.fixture +def temp_dir(): + temp_path = tempfile.mkdtemp() + yield temp_path + shutil.rmtree(temp_path) + + +def test_create_folder_structure_strips_single_trailing_slash(): + with tempfile.TemporaryDirectory() as temp_dir: + folder_path, folder_name, class_name = create_folder_structure( + "hello/", parent_folder=temp_dir + ) + + assert folder_name == "hello" + assert class_name == "Hello" + assert folder_path.name == "hello" + assert folder_path.exists() + assert folder_path.parent == Path(temp_dir) + + +def test_create_folder_structure_strips_multiple_trailing_slashes(): + with tempfile.TemporaryDirectory() as temp_dir: + folder_path, folder_name, class_name = create_folder_structure( + "hello///", parent_folder=temp_dir + ) + + assert folder_name == "hello" + assert class_name == "Hello" + assert folder_path.name == "hello" + assert folder_path.exists() + assert folder_path.parent == Path(temp_dir) + + +def test_create_folder_structure_handles_complex_name_with_trailing_slash(): + with tempfile.TemporaryDirectory() as temp_dir: + folder_path, folder_name, class_name = create_folder_structure( + "my-awesome_project/", parent_folder=temp_dir + ) + + assert folder_name == "my_awesome_project" + assert class_name == "MyAwesomeProject" + assert folder_path.name == "my_awesome_project" + assert folder_path.exists() + assert folder_path.parent == Path(temp_dir) + + +def test_create_folder_structure_normal_name_unchanged(): + with tempfile.TemporaryDirectory() as temp_dir: + folder_path, folder_name, class_name = create_folder_structure( + "hello", parent_folder=temp_dir + ) + + assert folder_name == "hello" + assert class_name == "Hello" + assert folder_path.name == "hello" + assert folder_path.exists() + assert folder_path.parent == Path(temp_dir) + + +def test_create_folder_structure_with_parent_folder(): + with tempfile.TemporaryDirectory() as temp_dir: + parent_path = Path(temp_dir) / "parent" + parent_path.mkdir() + + folder_path, folder_name, class_name = create_folder_structure( + "child/", parent_folder=parent_path + ) + + assert folder_name == "child" + assert class_name == "Child" + assert folder_path.name == "child" + assert folder_path.parent == parent_path + assert folder_path.exists() + + +@mock.patch("crewai_cli.create_crew.copy_template") +@mock.patch("crewai_cli.create_crew.write_env_file") +@mock.patch("crewai_cli.create_crew.load_env_vars") +def test_create_crew_with_trailing_slash_creates_valid_project( + mock_load_env, mock_write_env, mock_copy_template, temp_dir +): + mock_load_env.return_value = {} + + with tempfile.TemporaryDirectory() as work_dir: + with mock.patch( + "crewai_cli.create_crew.create_folder_structure" + ) as mock_create_folder: + mock_folder_path = Path(work_dir) / "test_project" + mock_create_folder.return_value = ( + mock_folder_path, + "test_project", + "TestProject", + ) + + create_crew("test-project/", skip_provider=True) + + mock_create_folder.assert_called_once_with("test-project/", None) + mock_copy_template.assert_called() + copy_calls = mock_copy_template.call_args_list + + for call in copy_calls: + args = call[0] + if len(args) >= 5: + folder_name_arg = args[4] + assert not folder_name_arg.endswith("/"), ( + f"folder_name should not end with slash: {folder_name_arg}" + ) + + +@mock.patch("crewai_cli.create_crew.copy_template") +@mock.patch("crewai_cli.create_crew.write_env_file") +@mock.patch("crewai_cli.create_crew.load_env_vars") +def test_create_crew_with_multiple_trailing_slashes( + mock_load_env, mock_write_env, mock_copy_template, temp_dir +): + mock_load_env.return_value = {} + + with tempfile.TemporaryDirectory() as work_dir: + with mock.patch( + "crewai_cli.create_crew.create_folder_structure" + ) as mock_create_folder: + mock_folder_path = Path(work_dir) / "test_project" + mock_create_folder.return_value = ( + mock_folder_path, + "test_project", + "TestProject", + ) + + create_crew("test-project///", skip_provider=True) + + mock_create_folder.assert_called_once_with("test-project///", None) + + +@mock.patch("crewai_cli.create_crew.copy_template") +@mock.patch("crewai_cli.create_crew.write_env_file") +@mock.patch("crewai_cli.create_crew.load_env_vars") +def test_create_crew_normal_name_still_works( + mock_load_env, mock_write_env, mock_copy_template, temp_dir +): + mock_load_env.return_value = {} + + with tempfile.TemporaryDirectory() as work_dir: + with mock.patch( + "crewai_cli.create_crew.create_folder_structure" + ) as mock_create_folder: + mock_folder_path = Path(work_dir) / "normal_project" + mock_create_folder.return_value = ( + mock_folder_path, + "normal_project", + "NormalProject", + ) + + create_crew("normal-project", skip_provider=True) + + mock_create_folder.assert_called_once_with("normal-project", None) + + +def test_create_folder_structure_handles_spaces_and_dashes_with_slash(): + with tempfile.TemporaryDirectory() as temp_dir: + folder_path, folder_name, class_name = create_folder_structure( + "My Cool-Project/", parent_folder=temp_dir + ) + + assert folder_name == "my_cool_project" + assert class_name == "MyCoolProject" + assert folder_path.name == "my_cool_project" + assert folder_path.exists() + assert folder_path.parent == Path(temp_dir) + + +def test_create_folder_structure_raises_error_for_invalid_names(): + with tempfile.TemporaryDirectory() as temp_dir: + invalid_cases = [ + ("123project/", "cannot start with a digit"), + ("True/", "reserved Python keyword"), + ("False/", "reserved Python keyword"), + ("None/", "reserved Python keyword"), + ("class/", "reserved Python keyword"), + ("def/", "reserved Python keyword"), + (" /", "empty or contain only whitespace"), + ("", "empty or contain only whitespace"), + ("@#$/", "contains no valid characters"), + ] + + for invalid_name, expected_error in invalid_cases: + with pytest.raises(ValueError, match=expected_error): + create_folder_structure(invalid_name, parent_folder=temp_dir) + + +def test_create_folder_structure_validates_names(): + with tempfile.TemporaryDirectory() as temp_dir: + valid_cases = [ + ("hello/", "hello", "Hello"), + ("my-project/", "my_project", "MyProject"), + ("hello_world/", "hello_world", "HelloWorld"), + ("valid123/", "valid123", "Valid123"), + ("hello.world/", "helloworld", "HelloWorld"), + ("hello@world/", "helloworld", "HelloWorld"), + ] + + for valid_name, expected_folder, expected_class in valid_cases: + folder_path, folder_name, class_name = create_folder_structure( + valid_name, parent_folder=temp_dir + ) + assert folder_name == expected_folder + assert class_name == expected_class + + assert folder_name.isidentifier(), ( + f"folder_name '{folder_name}' should be valid Python identifier" + ) + assert not keyword.iskeyword(folder_name), ( + f"folder_name '{folder_name}' should not be Python keyword" + ) + assert not folder_name[0].isdigit(), ( + f"folder_name '{folder_name}' should not start with digit" + ) + + assert class_name.isidentifier(), ( + f"class_name '{class_name}' should be valid Python identifier" + ) + assert not keyword.iskeyword(class_name), ( + f"class_name '{class_name}' should not be Python keyword" + ) + assert folder_path.parent == Path(temp_dir) + + if folder_path.exists(): + shutil.rmtree(folder_path) + + +@mock.patch("crewai_cli.create_crew.copy_template") +@mock.patch("crewai_cli.create_crew.write_env_file") +@mock.patch("crewai_cli.create_crew.load_env_vars") +def test_create_crew_with_parent_folder_and_trailing_slash( + mock_load_env, mock_write_env, mock_copy_template, temp_dir +): + mock_load_env.return_value = {} + + with tempfile.TemporaryDirectory() as work_dir: + parent_path = Path(work_dir) / "parent" + parent_path.mkdir() + + create_crew("child-crew/", skip_provider=True, parent_folder=parent_path) + + crew_path = parent_path / "child_crew" + assert crew_path.exists() + assert not (crew_path / "src").exists() + + +def test_create_folder_structure_folder_name_validation(): + """Test that folder names are validated as valid Python module names""" + with tempfile.TemporaryDirectory() as temp_dir: + folder_invalid_cases = [ + ("123invalid/", "cannot start with a digit.*invalid Python module name"), + ("import/", "reserved Python keyword"), + ("class/", "reserved Python keyword"), + ("for/", "reserved Python keyword"), + ("@#$invalid/", "contains no valid characters.*Python module name"), + ] + + for invalid_name, expected_error in folder_invalid_cases: + with pytest.raises(ValueError, match=expected_error): + create_folder_structure(invalid_name, parent_folder=temp_dir) + + valid_cases = [ + ("hello-world/", "hello_world"), + ("my.project/", "myproject"), + ("test@123/", "test123"), + ("valid_name/", "valid_name"), + ] + + for valid_name, expected_folder in valid_cases: + folder_path, folder_name, class_name = create_folder_structure( + valid_name, parent_folder=temp_dir + ) + assert folder_name == expected_folder + assert folder_name.isidentifier() + assert not keyword.iskeyword(folder_name) + + if folder_path.exists(): + shutil.rmtree(folder_path) + + +def test_create_folder_structure_rejects_reserved_names(): + """Test that reserved script names are rejected to prevent pyproject.toml conflicts.""" + with tempfile.TemporaryDirectory() as temp_dir: + reserved_names = ["test", "train", "replay", "run_crew", "run_with_trigger"] + + for reserved_name in reserved_names: + with pytest.raises(ValueError, match="which is reserved"): + create_folder_structure(reserved_name, parent_folder=temp_dir) + + with pytest.raises(ValueError, match="which is reserved"): + create_folder_structure(f"{reserved_name}/", parent_folder=temp_dir) + + capitalized = reserved_name.capitalize() + with pytest.raises(ValueError, match="which is reserved"): + create_folder_structure(capitalized, parent_folder=temp_dir) + + +@mock.patch("crewai_cli.create_crew.create_folder_structure") +@mock.patch("crewai_cli.create_crew.copy_template") +@mock.patch("crewai_cli.create_crew.load_env_vars") +@mock.patch("crewai_cli.create_crew.get_provider_data") +@mock.patch("crewai_cli.create_crew.select_provider") +@mock.patch("crewai_cli.create_crew.select_model") +@mock.patch("click.prompt") +def test_env_vars_are_uppercased_in_env_file( + mock_prompt, + mock_select_model, + mock_select_provider, + mock_get_provider_data, + mock_load_env_vars, + mock_copy_template, + mock_create_folder_structure, + tmp_path, +): + crew_path = tmp_path / "test_crew" + crew_path.mkdir() + mock_create_folder_structure.return_value = (crew_path, "test_crew", "TestCrew") + + mock_load_env_vars.return_value = {} + mock_get_provider_data.return_value = {"openai": ["gpt-4"]} + mock_select_provider.return_value = "azure" + mock_select_model.return_value = "azure/openai" + mock_prompt.return_value = "fake-api-key" + + create_crew("Test Crew") + + env_file_path = crew_path / ".env" + content = env_file_path.read_text() + assert "MODEL=" in content diff --git a/lib/cli/tests/test_crew_test.py b/lib/cli/tests/test_crew_test.py new file mode 100644 index 000000000..725a1b945 --- /dev/null +++ b/lib/cli/tests/test_crew_test.py @@ -0,0 +1,97 @@ +import subprocess +from unittest import mock + +import pytest + +from crewai_cli import evaluate_crew + + +@pytest.mark.parametrize( + "n_iterations,model", + [ + (1, "gpt-4o"), + (5, "gpt-3.5-turbo"), + (10, "gpt-4"), + ], +) +@mock.patch("crewai_cli.evaluate_crew.subprocess.run") +def test_crew_success(mock_subprocess_run, n_iterations, model): + """Test the crew function for successful execution.""" + mock_subprocess_run.return_value = subprocess.CompletedProcess( + args=f"uv run test {n_iterations} {model}", returncode=0 + ) + result = evaluate_crew.evaluate_crew(n_iterations, model) + + mock_subprocess_run.assert_called_once_with( + ["uv", "run", "test", str(n_iterations), model], + capture_output=False, + text=True, + check=True, + ) + assert result is None + + +@mock.patch("crewai_cli.evaluate_crew.click") +def test_test_crew_zero_iterations(click): + evaluate_crew.evaluate_crew(0, "gpt-4o") + click.echo.assert_called_once_with( + "An unexpected error occurred: The number of iterations must be a positive integer.", + err=True, + ) + + +@mock.patch("crewai_cli.evaluate_crew.click") +def test_test_crew_negative_iterations(click): + evaluate_crew.evaluate_crew(-2, "gpt-4o") + click.echo.assert_called_once_with( + "An unexpected error occurred: The number of iterations must be a positive integer.", + err=True, + ) + + +@mock.patch("crewai_cli.evaluate_crew.click") +@mock.patch("crewai_cli.evaluate_crew.subprocess.run") +def test_test_crew_called_process_error(mock_subprocess_run, click): + n_iterations = 5 + mock_subprocess_run.side_effect = subprocess.CalledProcessError( + returncode=1, + cmd=["uv", "run", "test", str(n_iterations), "gpt-4o"], + output="Error", + stderr="Some error occurred", + ) + evaluate_crew.evaluate_crew(n_iterations, "gpt-4o") + + mock_subprocess_run.assert_called_once_with( + ["uv", "run", "test", "5", "gpt-4o"], + capture_output=False, + text=True, + check=True, + ) + click.echo.assert_has_calls( + [ + mock.call.echo( + "An error occurred while testing the crew: Command '['uv', 'run', 'test', '5', 'gpt-4o']' returned non-zero exit status 1.", + err=True, + ), + mock.call.echo("Error", err=True), + ] + ) + + +@mock.patch("crewai_cli.evaluate_crew.click") +@mock.patch("crewai_cli.evaluate_crew.subprocess.run") +def test_test_crew_unexpected_exception(mock_subprocess_run, click): + # Arrange + n_iterations = 5 + mock_subprocess_run.side_effect = Exception("Unexpected error") + evaluate_crew.evaluate_crew(n_iterations, "gpt-4o") + + mock_subprocess_run.assert_called_once_with( + ["uv", "run", "test", "5", "gpt-4o"], + capture_output=False, + text=True, + check=True, + ) + click.echo.assert_called_once_with( + "An unexpected error occurred: Unexpected error", err=True + ) diff --git a/lib/cli/tests/test_git.py b/lib/cli/tests/test_git.py new file mode 100644 index 000000000..c6644990b --- /dev/null +++ b/lib/cli/tests/test_git.py @@ -0,0 +1,101 @@ +import pytest +from crewai_cli.git import Repository + + +@pytest.fixture() +def repository(fp): + fp.register(["git", "--version"], stdout="git version 2.30.0\n") + fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n") + fp.register(["git", "fetch"], stdout="") + return Repository(path=".") + + +def test_init_with_invalid_git_repo(fp): + fp.register(["git", "--version"], stdout="git version 2.30.0\n") + fp.register( + ["git", "rev-parse", "--is-inside-work-tree"], + returncode=1, + stderr="fatal: not a git repository\n", + ) + + with pytest.raises(ValueError): + Repository(path="invalid/path") + + +def test_is_git_not_installed(fp): + fp.register(["git", "--version"], returncode=1) + + with pytest.raises( + ValueError, match="Git is not installed or not found in your PATH." + ): + Repository(path=".") + + +def test_status(fp, repository): + fp.register( + ["git", "status", "--branch", "--porcelain"], + stdout="## main...origin/main [ahead 1]\n", + ) + assert repository.status() == "## main...origin/main [ahead 1]" + + +def test_has_uncommitted_changes(fp, repository): + fp.register( + ["git", "status", "--branch", "--porcelain"], + stdout="## main...origin/main\n M somefile.txt\n", + ) + assert repository.has_uncommitted_changes() is True + + +def test_is_ahead_or_behind(fp, repository): + fp.register( + ["git", "status", "--branch", "--porcelain"], + stdout="## main...origin/main [ahead 1]\n", + ) + assert repository.is_ahead_or_behind() is True + + +def test_is_synced_when_synced(fp, repository): + fp.register( + ["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n" + ) + fp.register( + ["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n" + ) + assert repository.is_synced() is True + + +def test_is_synced_with_uncommitted_changes(fp, repository): + fp.register( + ["git", "status", "--branch", "--porcelain"], + stdout="## main...origin/main\n M somefile.txt\n", + ) + assert repository.is_synced() is False + + +def test_is_synced_when_ahead_or_behind(fp, repository): + fp.register( + ["git", "status", "--branch", "--porcelain"], + stdout="## main...origin/main [ahead 1]\n", + ) + fp.register( + ["git", "status", "--branch", "--porcelain"], + stdout="## main...origin/main [ahead 1]\n", + ) + assert repository.is_synced() is False + + +def test_is_synced_with_uncommitted_changes_and_ahead(fp, repository): + fp.register( + ["git", "status", "--branch", "--porcelain"], + stdout="## main...origin/main [ahead 1]\n M somefile.txt\n", + ) + assert repository.is_synced() is False + + +def test_origin_url(fp, repository): + fp.register( + ["git", "remote", "get-url", "origin"], + stdout="https://github.com/user/repo.git\n", + ) + assert repository.origin_url() == "https://github.com/user/repo.git" diff --git a/lib/cli/tests/test_plus_api.py b/lib/cli/tests/test_plus_api.py new file mode 100644 index 000000000..75fe08a3b --- /dev/null +++ b/lib/cli/tests/test_plus_api.py @@ -0,0 +1,356 @@ +import os +import unittest +from unittest.mock import ANY, AsyncMock, MagicMock, patch + +import pytest + +from crewai_cli.plus_api import PlusAPI + + +class TestPlusAPI(unittest.TestCase): + def setUp(self): + self.api_key = "test_api_key" + self.api = PlusAPI(self.api_key) + self.org_uuid = "test-org-uuid" + + def test_init(self): + self.assertEqual(self.api.api_key, self.api_key) + self.assertEqual(self.api.headers["Authorization"], f"Bearer {self.api_key}") + self.assertEqual(self.api.headers["Content-Type"], "application/json") + self.assertTrue("CrewAI-CLI/" in self.api.headers["User-Agent"]) + self.assertTrue(self.api.headers["X-Crewai-Version"]) + + @patch("crewai_cli.plus_api.PlusAPI._make_request") + def test_login_to_tool_repository(self, mock_make_request): + mock_response = MagicMock() + mock_make_request.return_value = mock_response + + response = self.api.login_to_tool_repository() + + mock_make_request.assert_called_once_with( + "POST", "/crewai_plus/api/v1/tools/login" + ) + self.assertEqual(response, mock_response) + + def assert_request_with_org_id( + self, mock_client_instance, method: str, endpoint: str, **kwargs + ): + mock_client_instance.request.assert_called_once_with( + method, + f"{os.getenv('CREWAI_PLUS_URL')}{endpoint}", + headers={ + "Authorization": ANY, + "Content-Type": ANY, + "User-Agent": ANY, + "X-Crewai-Version": ANY, + "X-Crewai-Organization-Id": self.org_uuid, + }, + **kwargs, + ) + + @patch("crewai_cli.plus_api.Settings") + @patch("crewai_cli.plus_api.httpx.Client") + def test_login_to_tool_repository_with_org_uuid( + self, mock_client_class, mock_settings_class + ): + mock_settings = MagicMock() + mock_settings.org_uuid = self.org_uuid + mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL') + mock_settings_class.return_value = mock_settings + self.api = PlusAPI(self.api_key) + + mock_client_instance = MagicMock() + mock_response = MagicMock() + mock_client_instance.request.return_value = mock_response + mock_client_class.return_value.__enter__.return_value = mock_client_instance + + response = self.api.login_to_tool_repository() + + self.assert_request_with_org_id( + mock_client_instance, "POST", "/crewai_plus/api/v1/tools/login" + ) + self.assertEqual(response, mock_response) + + @patch("crewai_cli.plus_api.PlusAPI._make_request") + def test_get_tool(self, mock_make_request): + mock_response = MagicMock() + mock_make_request.return_value = mock_response + + response = self.api.get_tool("test_tool_handle") + mock_make_request.assert_called_once_with( + "GET", "/crewai_plus/api/v1/tools/test_tool_handle" + ) + self.assertEqual(response, mock_response) + + @patch("crewai_cli.plus_api.Settings") + @patch("crewai_cli.plus_api.httpx.Client") + def test_get_tool_with_org_uuid(self, mock_client_class, mock_settings_class): + mock_settings = MagicMock() + mock_settings.org_uuid = self.org_uuid + mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL') + mock_settings_class.return_value = mock_settings + self.api = PlusAPI(self.api_key) + + mock_client_instance = MagicMock() + mock_response = MagicMock() + mock_client_instance.request.return_value = mock_response + mock_client_class.return_value.__enter__.return_value = mock_client_instance + + response = self.api.get_tool("test_tool_handle") + + self.assert_request_with_org_id( + mock_client_instance, "GET", "/crewai_plus/api/v1/tools/test_tool_handle" + ) + self.assertEqual(response, mock_response) + + @patch("crewai_cli.plus_api.PlusAPI._make_request") + def test_publish_tool(self, mock_make_request): + mock_response = MagicMock() + mock_make_request.return_value = mock_response + handle = "test_tool_handle" + public = True + version = "1.0.0" + description = "Test tool description" + encoded_file = "encoded_test_file" + + response = self.api.publish_tool( + handle, public, version, description, encoded_file + ) + + params = { + "handle": handle, + "public": public, + "version": version, + "file": encoded_file, + "description": description, + "available_exports": None, + } + mock_make_request.assert_called_once_with( + "POST", "/crewai_plus/api/v1/tools", json=params + ) + self.assertEqual(response, mock_response) + + @patch("crewai_cli.plus_api.Settings") + @patch("crewai_cli.plus_api.httpx.Client") + def test_publish_tool_with_org_uuid(self, mock_client_class, mock_settings_class): + mock_settings = MagicMock() + mock_settings.org_uuid = self.org_uuid + mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL') + mock_settings_class.return_value = mock_settings + self.api = PlusAPI(self.api_key) + + mock_client_instance = MagicMock() + mock_response = MagicMock() + mock_client_instance.request.return_value = mock_response + mock_client_class.return_value.__enter__.return_value = mock_client_instance + + handle = "test_tool_handle" + public = True + version = "1.0.0" + description = "Test tool description" + encoded_file = "encoded_test_file" + + response = self.api.publish_tool( + handle, public, version, description, encoded_file + ) + + expected_params = { + "handle": handle, + "public": public, + "version": version, + "file": encoded_file, + "description": description, + "available_exports": None, + } + + self.assert_request_with_org_id( + mock_client_instance, "POST", "/crewai_plus/api/v1/tools", json=expected_params + ) + self.assertEqual(response, mock_response) + + @patch("crewai_cli.plus_api.PlusAPI._make_request") + def test_publish_tool_without_description(self, mock_make_request): + mock_response = MagicMock() + mock_make_request.return_value = mock_response + handle = "test_tool_handle" + public = False + version = "2.0.0" + description = None + encoded_file = "encoded_test_file" + + response = self.api.publish_tool( + handle, public, version, description, encoded_file + ) + + params = { + "handle": handle, + "public": public, + "version": version, + "file": encoded_file, + "description": description, + "available_exports": None, + } + mock_make_request.assert_called_once_with( + "POST", "/crewai_plus/api/v1/tools", json=params + ) + self.assertEqual(response, mock_response) + + @patch("crewai_cli.plus_api.httpx.Client") + def test_make_request(self, mock_client_class): + mock_client_instance = MagicMock() + mock_response = MagicMock() + mock_client_instance.request.return_value = mock_response + mock_client_class.return_value.__enter__.return_value = mock_client_instance + + response = self.api._make_request("GET", "test_endpoint") + + mock_client_class.assert_called_once_with(trust_env=False, verify=True) + mock_client_instance.request.assert_called_once_with( + "GET", f"{self.api.base_url}/test_endpoint", headers=self.api.headers + ) + self.assertEqual(response, mock_response) + + @patch("crewai_cli.plus_api.PlusAPI._make_request") + def test_deploy_by_name(self, mock_make_request): + self.api.deploy_by_name("test_project") + mock_make_request.assert_called_once_with( + "POST", "/crewai_plus/api/v1/crews/by-name/test_project/deploy" + ) + + @patch("crewai_cli.plus_api.PlusAPI._make_request") + def test_deploy_by_uuid(self, mock_make_request): + self.api.deploy_by_uuid("test_uuid") + mock_make_request.assert_called_once_with( + "POST", "/crewai_plus/api/v1/crews/test_uuid/deploy" + ) + + @patch("crewai_cli.plus_api.PlusAPI._make_request") + def test_crew_status_by_name(self, mock_make_request): + self.api.crew_status_by_name("test_project") + mock_make_request.assert_called_once_with( + "GET", "/crewai_plus/api/v1/crews/by-name/test_project/status" + ) + + @patch("crewai_cli.plus_api.PlusAPI._make_request") + def test_crew_status_by_uuid(self, mock_make_request): + self.api.crew_status_by_uuid("test_uuid") + mock_make_request.assert_called_once_with( + "GET", "/crewai_plus/api/v1/crews/test_uuid/status" + ) + + @patch("crewai_cli.plus_api.PlusAPI._make_request") + def test_crew_by_name(self, mock_make_request): + self.api.crew_by_name("test_project") + mock_make_request.assert_called_once_with( + "GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/deployment" + ) + + self.api.crew_by_name("test_project", "custom_log") + mock_make_request.assert_called_with( + "GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log" + ) + + @patch("crewai_cli.plus_api.PlusAPI._make_request") + def test_crew_by_uuid(self, mock_make_request): + self.api.crew_by_uuid("test_uuid") + mock_make_request.assert_called_once_with( + "GET", "/crewai_plus/api/v1/crews/test_uuid/logs/deployment" + ) + + self.api.crew_by_uuid("test_uuid", "custom_log") + mock_make_request.assert_called_with( + "GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log" + ) + + @patch("crewai_cli.plus_api.PlusAPI._make_request") + def test_delete_crew_by_name(self, mock_make_request): + self.api.delete_crew_by_name("test_project") + mock_make_request.assert_called_once_with( + "DELETE", "/crewai_plus/api/v1/crews/by-name/test_project" + ) + + @patch("crewai_cli.plus_api.PlusAPI._make_request") + def test_delete_crew_by_uuid(self, mock_make_request): + self.api.delete_crew_by_uuid("test_uuid") + mock_make_request.assert_called_once_with( + "DELETE", "/crewai_plus/api/v1/crews/test_uuid" + ) + + @patch("crewai_cli.plus_api.PlusAPI._make_request") + def test_list_crews(self, mock_make_request): + self.api.list_crews() + mock_make_request.assert_called_once_with("GET", "/crewai_plus/api/v1/crews") + + @patch("crewai_cli.plus_api.PlusAPI._make_request") + def test_create_crew(self, mock_make_request): + payload = {"name": "test_crew"} + self.api.create_crew(payload) + mock_make_request.assert_called_once_with( + "POST", "/crewai_plus/api/v1/crews", json=payload + ) + + @patch("crewai_cli.plus_api.Settings") + @patch.dict(os.environ, {"CREWAI_PLUS_URL": ""}) + def test_custom_base_url(self, mock_settings_class): + mock_settings = MagicMock() + mock_settings.enterprise_base_url = "https://custom-url.com/api" + mock_settings_class.return_value = mock_settings + custom_api = PlusAPI("test_key") + self.assertEqual( + custom_api.base_url, + "https://custom-url.com/api", + ) + + @patch.dict(os.environ, {"CREWAI_PLUS_URL": "https://custom-url-from-env.com"}) + def test_custom_base_url_from_env(self): + custom_api = PlusAPI("test_key") + self.assertEqual( + custom_api.base_url, + "https://custom-url-from-env.com", + ) + + +@pytest.mark.asyncio +@patch("httpx.AsyncClient") +async def test_get_agent(mock_async_client_class): + api = PlusAPI("test_api_key") + mock_response = MagicMock() + mock_client_instance = AsyncMock() + mock_client_instance.get.return_value = mock_response + mock_async_client_class.return_value.__aenter__.return_value = mock_client_instance + + response = await api.get_agent("test_agent_handle") + + mock_client_instance.get.assert_called_once_with( + f"{api.base_url}/crewai_plus/api/v1/agents/test_agent_handle", + headers=api.headers, + ) + assert response == mock_response + + +@pytest.mark.asyncio +@patch("httpx.AsyncClient") +@patch("crewai_cli.plus_api.Settings") +async def test_get_agent_with_org_uuid(mock_settings_class, mock_async_client_class): + org_uuid = "test-org-uuid" + mock_settings = MagicMock() + mock_settings.org_uuid = org_uuid + mock_settings.enterprise_base_url = os.getenv("CREWAI_PLUS_URL") + mock_settings_class.return_value = mock_settings + + api = PlusAPI("test_api_key") + + mock_response = MagicMock() + mock_client_instance = AsyncMock() + mock_client_instance.get.return_value = mock_response + mock_async_client_class.return_value.__aenter__.return_value = mock_client_instance + + response = await api.get_agent("test_agent_handle") + + mock_client_instance.get.assert_called_once_with( + f"{api.base_url}/crewai_plus/api/v1/agents/test_agent_handle", + headers=api.headers, + ) + assert "X-Crewai-Organization-Id" in api.headers + assert api.headers["X-Crewai-Organization-Id"] == org_uuid + assert response == mock_response diff --git a/lib/cli/tests/test_settings_command.py b/lib/cli/tests/test_settings_command.py new file mode 100644 index 000000000..c788ff453 --- /dev/null +++ b/lib/cli/tests/test_settings_command.py @@ -0,0 +1,90 @@ +import tempfile +import unittest +from pathlib import Path +from unittest.mock import patch, MagicMock, call + +from crewai_cli.settings.main import SettingsCommand +from crewai_cli.config import ( + Settings, + USER_SETTINGS_KEYS, + CLI_SETTINGS_KEYS, + DEFAULT_CLI_SETTINGS, + HIDDEN_SETTINGS_KEYS, + READONLY_SETTINGS_KEYS, +) +import shutil + + +class TestSettingsCommand(unittest.TestCase): + def setUp(self): + self.test_dir = Path(tempfile.mkdtemp()) + self.config_path = self.test_dir / "settings.json" + self.settings = Settings(config_path=self.config_path) + self.settings_command = SettingsCommand( + settings_kwargs={"config_path": self.config_path} + ) + + def tearDown(self): + shutil.rmtree(self.test_dir) + + @patch("crewai_cli.settings.main.console") + @patch("crewai_cli.settings.main.Table") + def test_list_settings(self, mock_table_class, mock_console): + mock_table_instance = MagicMock() + mock_table_class.return_value = mock_table_instance + + self.settings_command.list() + + # Tests that the table is created skipping hidden settings + mock_table_instance.add_row.assert_has_calls( + [ + call( + field_name, + getattr(self.settings, field_name) or "Not set", + field_info.description, + ) + for field_name, field_info in Settings.model_fields.items() + if field_name not in HIDDEN_SETTINGS_KEYS + ] + ) + + # Tests that the table is printed + mock_console.print.assert_called_once_with(mock_table_instance) + + def test_set_valid_keys(self): + valid_keys = Settings.model_fields.keys() - ( + READONLY_SETTINGS_KEYS + HIDDEN_SETTINGS_KEYS + ) + for key in valid_keys: + test_value = f"some_value_for_{key}" + self.settings_command.set(key, test_value) + self.assertEqual(getattr(self.settings_command.settings, key), test_value) + + def test_set_invalid_key(self): + with self.assertRaises(SystemExit): + self.settings_command.set("invalid_key", "value") + + def test_set_readonly_keys(self): + for key in READONLY_SETTINGS_KEYS: + with self.assertRaises(SystemExit): + self.settings_command.set(key, "some_readonly_key_value") + + def test_set_hidden_keys(self): + for key in HIDDEN_SETTINGS_KEYS: + with self.assertRaises(SystemExit): + self.settings_command.set(key, "some_hidden_key_value") + + def test_reset_all_settings(self): + for key in USER_SETTINGS_KEYS + CLI_SETTINGS_KEYS: + setattr(self.settings_command.settings, key, f"custom_value_for_{key}") + self.settings_command.settings.dump() + + self.settings_command.reset_all_settings() + + for key in USER_SETTINGS_KEYS: + self.assertEqual(getattr(self.settings_command.settings, key), None) + + for key in CLI_SETTINGS_KEYS: + self.assertEqual( + getattr(self.settings_command.settings, key), DEFAULT_CLI_SETTINGS.get(key) + ) diff --git a/lib/cli/tests/test_token_manager.py b/lib/cli/tests/test_token_manager.py new file mode 100644 index 000000000..1cc14abdf --- /dev/null +++ b/lib/cli/tests/test_token_manager.py @@ -0,0 +1,294 @@ +"""Tests for TokenManager with atomic file operations.""" + +import json +import os +import tempfile +import unittest +from datetime import datetime, timedelta +from pathlib import Path +from unittest.mock import patch + +from cryptography.fernet import Fernet + +from crewai_cli.shared.token_manager import TokenManager + + +class TestTokenManager(unittest.TestCase): + """Test cases for TokenManager.""" + + @patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key") + def setUp(self, mock_get_key: unittest.mock.MagicMock) -> None: + """Set up test fixtures.""" + mock_get_key.return_value = Fernet.generate_key() + self.token_manager = TokenManager() + + @patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file") + @patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key") + def test_get_or_create_key_existing( + self, + mock_get_or_create: unittest.mock.MagicMock, + mock_read: unittest.mock.MagicMock, + ) -> None: + """Test that existing key is returned when present.""" + mock_key = Fernet.generate_key() + mock_get_or_create.return_value = mock_key + + token_manager = TokenManager() + result = token_manager.key + + self.assertEqual(result, mock_key) + + def test_get_or_create_key_new(self) -> None: + """Test that new key is created when none exists.""" + mock_key = Fernet.generate_key() + + with ( + patch.object(self.token_manager, "_read_secure_file", return_value=None) as mock_read, + patch.object(self.token_manager, "_atomic_create_secure_file", return_value=True) as mock_atomic_create, + patch("crewai_cli.shared.token_manager.Fernet.generate_key", return_value=mock_key) as mock_generate, + ): + result = self.token_manager._get_or_create_key() + + self.assertEqual(result, mock_key) + mock_read.assert_called_with("secret.key") + mock_generate.assert_called_once() + mock_atomic_create.assert_called_once_with("secret.key", mock_key) + + def test_get_or_create_key_race_condition(self) -> None: + """Test that another process's key is used when atomic create fails.""" + our_key = Fernet.generate_key() + their_key = Fernet.generate_key() + + with ( + patch.object(self.token_manager, "_read_secure_file", side_effect=[None, their_key]) as mock_read, + patch.object(self.token_manager, "_atomic_create_secure_file", return_value=False) as mock_atomic_create, + patch("crewai_cli.shared.token_manager.Fernet.generate_key", return_value=our_key), + ): + result = self.token_manager._get_or_create_key() + + self.assertEqual(result, their_key) + self.assertEqual(mock_read.call_count, 2) + + @patch("crewai_cli.shared.token_manager.TokenManager._atomic_write_secure_file") + def test_save_tokens( + self, mock_write: unittest.mock.MagicMock + ) -> None: + """Test saving tokens encrypts and writes atomically.""" + access_token = "test_token" + expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp()) + + self.token_manager.save_tokens(access_token, expires_at) + + mock_write.assert_called_once() + args = mock_write.call_args[0] + self.assertEqual(args[0], "tokens.enc") + decrypted_data = self.token_manager.fernet.decrypt(args[1]) + data = json.loads(decrypted_data) + self.assertEqual(data["access_token"], access_token) + expiration = datetime.fromisoformat(data["expiration"]) + self.assertEqual(expiration, datetime.fromtimestamp(expires_at)) + + @patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file") + def test_get_token_valid( + self, mock_read: unittest.mock.MagicMock + ) -> None: + """Test getting a valid non-expired token.""" + access_token = "test_token" + expiration = (datetime.now() + timedelta(hours=1)).isoformat() + data = {"access_token": access_token, "expiration": expiration} + encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode()) + mock_read.return_value = encrypted_data + + result = self.token_manager.get_token() + + self.assertEqual(result, access_token) + + @patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file") + def test_get_token_expired( + self, mock_read: unittest.mock.MagicMock + ) -> None: + """Test that expired token returns None.""" + access_token = "test_token" + expiration = (datetime.now() - timedelta(hours=1)).isoformat() + data = {"access_token": access_token, "expiration": expiration} + encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode()) + mock_read.return_value = encrypted_data + + result = self.token_manager.get_token() + + self.assertIsNone(result) + + @patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file") + def test_get_token_not_found( + self, mock_read: unittest.mock.MagicMock + ) -> None: + """Test that missing token file returns None.""" + mock_read.return_value = None + + result = self.token_manager.get_token() + + self.assertIsNone(result) + + @patch("crewai_cli.shared.token_manager.TokenManager._delete_secure_file") + def test_clear_tokens( + self, mock_delete: unittest.mock.MagicMock + ) -> None: + """Test clearing tokens deletes the token file.""" + self.token_manager.clear_tokens() + + mock_delete.assert_called_once_with("tokens.enc") + + +class TestAtomicFileOperations(unittest.TestCase): + """Test atomic file operations directly.""" + + def setUp(self) -> None: + """Set up test fixtures with temp directory.""" + self.temp_dir = tempfile.mkdtemp() + self.original_get_path = TokenManager._get_secure_storage_path + + # Patch to use temp directory + def mock_get_path() -> Path: + return Path(self.temp_dir) + + TokenManager._get_secure_storage_path = staticmethod(mock_get_path) + + def tearDown(self) -> None: + """Clean up temp directory.""" + TokenManager._get_secure_storage_path = staticmethod(self.original_get_path) + import shutil + shutil.rmtree(self.temp_dir, ignore_errors=True) + + @patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key") + def test_atomic_create_new_file( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test atomic create succeeds for new file.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + result = tm._atomic_create_secure_file("test.txt", b"content") + + self.assertTrue(result) + file_path = Path(self.temp_dir) / "test.txt" + self.assertTrue(file_path.exists()) + self.assertEqual(file_path.read_bytes(), b"content") + self.assertEqual(file_path.stat().st_mode & 0o777, 0o600) + + @patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key") + def test_atomic_create_existing_file( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test atomic create fails for existing file.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + # Create file first + file_path = Path(self.temp_dir) / "test.txt" + file_path.write_bytes(b"original") + + result = tm._atomic_create_secure_file("test.txt", b"new content") + + self.assertFalse(result) + self.assertEqual(file_path.read_bytes(), b"original") + + @patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key") + def test_atomic_write_new_file( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test atomic write creates new file.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + tm._atomic_write_secure_file("test.txt", b"content") + + file_path = Path(self.temp_dir) / "test.txt" + self.assertTrue(file_path.exists()) + self.assertEqual(file_path.read_bytes(), b"content") + self.assertEqual(file_path.stat().st_mode & 0o777, 0o600) + + @patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key") + def test_atomic_write_overwrites( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test atomic write overwrites existing file.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + file_path = Path(self.temp_dir) / "test.txt" + file_path.write_bytes(b"original") + + tm._atomic_write_secure_file("test.txt", b"new content") + + self.assertEqual(file_path.read_bytes(), b"new content") + + @patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key") + def test_atomic_write_no_temp_file_on_success( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test that temp file is cleaned up after successful write.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + tm._atomic_write_secure_file("test.txt", b"content") + + # Check no temp files remain + temp_files = list(Path(self.temp_dir).glob(".test.txt.*")) + self.assertEqual(len(temp_files), 0) + + @patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key") + def test_read_secure_file_exists( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test reading existing file.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + file_path = Path(self.temp_dir) / "test.txt" + file_path.write_bytes(b"content") + + result = tm._read_secure_file("test.txt") + + self.assertEqual(result, b"content") + + @patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key") + def test_read_secure_file_not_exists( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test reading non-existent file returns None.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + result = tm._read_secure_file("nonexistent.txt") + + self.assertIsNone(result) + + @patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key") + def test_delete_secure_file_exists( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test deleting existing file.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + file_path = Path(self.temp_dir) / "test.txt" + file_path.write_bytes(b"content") + + tm._delete_secure_file("test.txt") + + self.assertFalse(file_path.exists()) + + @patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key") + def test_delete_secure_file_not_exists( + self, mock_get_key: unittest.mock.MagicMock + ) -> None: + """Test deleting non-existent file doesn't raise.""" + mock_get_key.return_value = Fernet.generate_key() + tm = TokenManager() + + # Should not raise + tm._delete_secure_file("nonexistent.txt") + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/lib/cli/tests/test_train_crew.py b/lib/cli/tests/test_train_crew.py new file mode 100644 index 000000000..47263032e --- /dev/null +++ b/lib/cli/tests/test_train_crew.py @@ -0,0 +1,89 @@ +import subprocess +from unittest import mock + +from crewai_cli.train_crew import train_crew + + +@mock.patch("crewai_cli.train_crew.subprocess.run") +def test_train_crew_positive_iterations(mock_subprocess_run): + n_iterations = 5 + mock_subprocess_run.return_value = subprocess.CompletedProcess( + args=["uv", "run", "train", str(n_iterations)], + returncode=0, + stdout="Success", + stderr="", + ) + + train_crew(n_iterations, "trained_agents_data.pkl") + + mock_subprocess_run.assert_called_once_with( + ["uv", "run", "train", str(n_iterations), "trained_agents_data.pkl"], + capture_output=False, + text=True, + check=True, + ) + + +@mock.patch("crewai_cli.train_crew.click") +def test_train_crew_zero_iterations(click): + train_crew(0, "trained_agents_data.pkl") + click.echo.assert_called_once_with( + "An unexpected error occurred: The number of iterations must be a positive integer.", + err=True, + ) + + +@mock.patch("crewai_cli.train_crew.click") +def test_train_crew_negative_iterations(click): + train_crew(-2, "trained_agents_data.pkl") + click.echo.assert_called_once_with( + "An unexpected error occurred: The number of iterations must be a positive integer.", + err=True, + ) + + +@mock.patch("crewai_cli.train_crew.click") +@mock.patch("crewai_cli.train_crew.subprocess.run") +def test_train_crew_called_process_error(mock_subprocess_run, click): + n_iterations = 5 + mock_subprocess_run.side_effect = subprocess.CalledProcessError( + returncode=1, + cmd=["uv", "run", "train", str(n_iterations)], + output="Error", + stderr="Some error occurred", + ) + train_crew(n_iterations, "trained_agents_data.pkl") + + mock_subprocess_run.assert_called_once_with( + ["uv", "run", "train", str(n_iterations), "trained_agents_data.pkl"], + capture_output=False, + text=True, + check=True, + ) + click.echo.assert_has_calls( + [ + mock.call.echo( + "An error occurred while training the crew: Command '['uv', 'run', 'train', '5']' returned non-zero exit status 1.", + err=True, + ), + mock.call.echo("Error", err=True), + ] + ) + + +@mock.patch("crewai_cli.train_crew.click") +@mock.patch("crewai_cli.train_crew.subprocess.run") +def test_train_crew_unexpected_exception(mock_subprocess_run, click): + n_iterations = 5 + mock_subprocess_run.side_effect = Exception("Unexpected error") + train_crew(n_iterations, "trained_agents_data.pkl") + + mock_subprocess_run.assert_called_once_with( + ["uv", "run", "train", str(n_iterations), "trained_agents_data.pkl"], + capture_output=False, + text=True, + check=True, + ) + click.echo.assert_called_once_with( + "An unexpected error occurred: Unexpected error", err=True + ) diff --git a/lib/cli/tests/test_utils.py b/lib/cli/tests/test_utils.py new file mode 100644 index 000000000..41cdaaa49 --- /dev/null +++ b/lib/cli/tests/test_utils.py @@ -0,0 +1,146 @@ +import os +import shutil +import tempfile +from pathlib import Path + +import pytest +from crewai_cli import utils + + +@pytest.fixture +def temp_tree(): + root_dir = tempfile.mkdtemp() + + create_file(os.path.join(root_dir, "file1.txt"), "Hello, world!") + create_file(os.path.join(root_dir, "file2.txt"), "Another file") + os.mkdir(os.path.join(root_dir, "empty_dir")) + nested_dir = os.path.join(root_dir, "nested_dir") + os.mkdir(nested_dir) + create_file(os.path.join(nested_dir, "nested_file.txt"), "Nested content") + + yield root_dir + + shutil.rmtree(root_dir) + + +def create_file(path, content): + with open(path, "w") as f: + f.write(content) + + +def test_tree_find_and_replace_file_content(temp_tree): + utils.tree_find_and_replace(temp_tree, "world", "universe") + with open(os.path.join(temp_tree, "file1.txt"), "r") as f: + assert f.read() == "Hello, universe!" + + +def test_tree_find_and_replace_file_name(temp_tree): + old_path = os.path.join(temp_tree, "file2.txt") + new_path = os.path.join(temp_tree, "file2_renamed.txt") + os.rename(old_path, new_path) + utils.tree_find_and_replace(temp_tree, "renamed", "modified") + assert os.path.exists(os.path.join(temp_tree, "file2_modified.txt")) + assert not os.path.exists(new_path) + + +def test_tree_find_and_replace_directory_name(temp_tree): + utils.tree_find_and_replace(temp_tree, "empty", "renamed") + assert os.path.exists(os.path.join(temp_tree, "renamed_dir")) + assert not os.path.exists(os.path.join(temp_tree, "empty_dir")) + + +def test_tree_find_and_replace_nested_content(temp_tree): + utils.tree_find_and_replace(temp_tree, "Nested", "Updated") + with open(os.path.join(temp_tree, "nested_dir", "nested_file.txt"), "r") as f: + assert f.read() == "Updated content" + + +def test_tree_find_and_replace_no_matches(temp_tree): + utils.tree_find_and_replace(temp_tree, "nonexistent", "replacement") + assert set(os.listdir(temp_tree)) == { + "file1.txt", + "file2.txt", + "empty_dir", + "nested_dir", + } + + +def test_tree_copy_full_structure(temp_tree): + dest_dir = tempfile.mkdtemp() + try: + utils.tree_copy(temp_tree, dest_dir) + assert set(os.listdir(dest_dir)) == set(os.listdir(temp_tree)) + assert os.path.isfile(os.path.join(dest_dir, "file1.txt")) + assert os.path.isfile(os.path.join(dest_dir, "file2.txt")) + assert os.path.isdir(os.path.join(dest_dir, "empty_dir")) + assert os.path.isdir(os.path.join(dest_dir, "nested_dir")) + assert os.path.isfile(os.path.join(dest_dir, "nested_dir", "nested_file.txt")) + finally: + shutil.rmtree(dest_dir) + + +def test_tree_copy_preserve_content(temp_tree): + dest_dir = tempfile.mkdtemp() + try: + utils.tree_copy(temp_tree, dest_dir) + with open(os.path.join(dest_dir, "file1.txt"), "r") as f: + assert f.read() == "Hello, world!" + with open(os.path.join(dest_dir, "nested_dir", "nested_file.txt"), "r") as f: + assert f.read() == "Nested content" + finally: + shutil.rmtree(dest_dir) + + +def test_tree_copy_to_existing_directory(temp_tree): + dest_dir = tempfile.mkdtemp() + try: + create_file(os.path.join(dest_dir, "existing_file.txt"), "I was here first") + utils.tree_copy(temp_tree, dest_dir) + assert os.path.isfile(os.path.join(dest_dir, "existing_file.txt")) + assert os.path.isfile(os.path.join(dest_dir, "file1.txt")) + finally: + shutil.rmtree(dest_dir) + + +@pytest.fixture +def temp_project_dir(): + """Create a temporary directory for testing tool extraction.""" + with tempfile.TemporaryDirectory() as temp_dir: + yield Path(temp_dir) + + +def create_init_file(directory, content): + return create_file(directory / "__init__.py", content) + + +def test_extract_available_exports_empty_project(temp_project_dir, capsys): + with pytest.raises(SystemExit): + utils.extract_available_exports(dir_path=temp_project_dir) + captured = capsys.readouterr() + + assert "No valid tools were exposed in your __init__.py file" in captured.out + + +def test_extract_available_exports_no_init_file(temp_project_dir, capsys): + (temp_project_dir / "some_file.py").write_text("print('hello')") + with pytest.raises(SystemExit): + utils.extract_available_exports(dir_path=temp_project_dir) + captured = capsys.readouterr() + + assert "No valid tools were exposed in your __init__.py file" in captured.out + + +def test_extract_available_exports_empty_init_file(temp_project_dir, capsys): + create_init_file(temp_project_dir, "") + with pytest.raises(SystemExit): + utils.extract_available_exports(dir_path=temp_project_dir) + captured = capsys.readouterr() + + assert "Warning: No __all__ defined in" in captured.out + + +# Tests for extract_available_exports with crewai.tools (BaseTool, @tool) +# remain in lib/crewai/tests/cli/test_utils.py as they require the crewai core package. + +# Tests for get_crews, get_flows, fetch_crews, is_valid_tool +# remain in lib/crewai/tests/cli/test_utils.py as they require the crewai core package. diff --git a/lib/cli/tests/test_version.py b/lib/cli/tests/test_version.py new file mode 100644 index 000000000..b6794b00a --- /dev/null +++ b/lib/cli/tests/test_version.py @@ -0,0 +1,372 @@ +"""Test for version management.""" + +import json +from datetime import datetime, timedelta +from pathlib import Path +from unittest.mock import MagicMock, patch + +from crewai_cli.version import get_crewai_version as _get_ver +from crewai_cli.version import ( + _find_latest_non_yanked_version, + _get_cache_file, + _is_cache_valid, + _is_version_yanked, + get_crewai_version, + get_latest_version_from_pypi, + is_current_version_yanked, + is_newer_version_available, +) + + +def test_dynamic_versioning_consistency() -> None: + """Test that dynamic versioning provides consistent version across all access methods.""" + cli_version = get_crewai_version() + package_version = _get_ver() + + assert cli_version == package_version + + assert package_version is not None + assert len(package_version.strip()) > 0 + + +class TestVersionChecking: + """Test version checking utilities.""" + + def test_get_crewai_version(self) -> None: + """Test getting current crewai version.""" + version = get_crewai_version() + assert isinstance(version, str) + assert len(version) > 0 + + def test_get_cache_file(self) -> None: + """Test cache file path generation.""" + cache_file = _get_cache_file() + assert isinstance(cache_file, Path) + assert cache_file.name == "version_cache.json" + + def test_is_cache_valid_with_fresh_cache(self) -> None: + """Test cache validation with fresh cache.""" + cache_data = {"timestamp": datetime.now().isoformat(), "version": "1.0.0"} + assert _is_cache_valid(cache_data) is True + + def test_is_cache_valid_with_stale_cache(self) -> None: + """Test cache validation with stale cache.""" + old_time = datetime.now() - timedelta(hours=25) + cache_data = {"timestamp": old_time.isoformat(), "version": "1.0.0"} + assert _is_cache_valid(cache_data) is False + + def test_is_cache_valid_with_missing_timestamp(self) -> None: + """Test cache validation with missing timestamp.""" + cache_data = {"version": "1.0.0"} + assert _is_cache_valid(cache_data) is False + + @patch("crewai_cli.version.Path.exists") + @patch("crewai_cli.version.request.urlopen") + def test_get_latest_version_from_pypi_success( + self, mock_urlopen: MagicMock, mock_exists: MagicMock + ) -> None: + """Test successful PyPI version fetch uses releases data.""" + mock_exists.return_value = False + + releases = { + "1.0.0": [{"yanked": False}], + "2.0.0": [{"yanked": False}], + "2.1.0": [{"yanked": True, "yanked_reason": "bad release"}], + } + mock_response = MagicMock() + mock_response.read.return_value = json.dumps( + {"info": {"version": "2.1.0"}, "releases": releases} + ).encode() + mock_urlopen.return_value.__enter__.return_value = mock_response + + version = get_latest_version_from_pypi() + assert version == "2.0.0" + + @patch("crewai_cli.version.Path.exists") + @patch("crewai_cli.version.request.urlopen") + def test_get_latest_version_from_pypi_failure( + self, mock_urlopen: MagicMock, mock_exists: MagicMock + ) -> None: + """Test PyPI version fetch failure.""" + from urllib.error import URLError + + mock_exists.return_value = False + + mock_urlopen.side_effect = URLError("Network error") + + version = get_latest_version_from_pypi() + assert version is None + + @patch("crewai_cli.version.get_crewai_version") + @patch("crewai_cli.version.get_latest_version_from_pypi") + def test_is_newer_version_available_true( + self, mock_latest: MagicMock, mock_current: MagicMock + ) -> None: + """Test when newer version is available.""" + mock_current.return_value = "1.0.0" + mock_latest.return_value = "2.0.0" + + is_newer, current, latest = is_newer_version_available() + assert is_newer is True + assert current == "1.0.0" + assert latest == "2.0.0" + + @patch("crewai_cli.version.get_crewai_version") + @patch("crewai_cli.version.get_latest_version_from_pypi") + def test_is_newer_version_available_false( + self, mock_latest: MagicMock, mock_current: MagicMock + ) -> None: + """Test when no newer version is available.""" + mock_current.return_value = "2.0.0" + mock_latest.return_value = "2.0.0" + + is_newer, current, latest = is_newer_version_available() + assert is_newer is False + assert current == "2.0.0" + assert latest == "2.0.0" + + @patch("crewai_cli.version.get_crewai_version") + @patch("crewai_cli.version.get_latest_version_from_pypi") + def test_is_newer_version_available_with_none_latest( + self, mock_latest: MagicMock, mock_current: MagicMock + ) -> None: + """Test when PyPI fetch fails.""" + mock_current.return_value = "1.0.0" + mock_latest.return_value = None + + is_newer, current, latest = is_newer_version_available() + assert is_newer is False + assert current == "1.0.0" + assert latest is None + + +class TestFindLatestNonYankedVersion: + """Test _find_latest_non_yanked_version helper.""" + + def test_skips_yanked_versions(self) -> None: + """Test that yanked versions are skipped.""" + releases = { + "1.0.0": [{"yanked": False}], + "2.0.0": [{"yanked": True}], + } + assert _find_latest_non_yanked_version(releases) == "1.0.0" + + def test_returns_highest_non_yanked(self) -> None: + """Test that the highest non-yanked version is returned.""" + releases = { + "1.0.0": [{"yanked": False}], + "1.5.0": [{"yanked": False}], + "2.0.0": [{"yanked": True}], + } + assert _find_latest_non_yanked_version(releases) == "1.5.0" + + def test_returns_none_when_all_yanked(self) -> None: + """Test that None is returned when all versions are yanked.""" + releases = { + "1.0.0": [{"yanked": True}], + "2.0.0": [{"yanked": True}], + } + assert _find_latest_non_yanked_version(releases) is None + + def test_skips_prerelease_versions(self) -> None: + """Test that pre-release versions are skipped.""" + releases = { + "1.0.0": [{"yanked": False}], + "2.0.0a1": [{"yanked": False}], + "2.0.0rc1": [{"yanked": False}], + } + assert _find_latest_non_yanked_version(releases) == "1.0.0" + + def test_skips_versions_with_empty_files(self) -> None: + """Test that versions with no files are skipped.""" + releases: dict[str, list[dict[str, bool]]] = { + "1.0.0": [{"yanked": False}], + "2.0.0": [], + } + assert _find_latest_non_yanked_version(releases) == "1.0.0" + + def test_handles_invalid_version_strings(self) -> None: + """Test that invalid version strings are skipped.""" + releases = { + "1.0.0": [{"yanked": False}], + "not-a-version": [{"yanked": False}], + } + assert _find_latest_non_yanked_version(releases) == "1.0.0" + + def test_partially_yanked_files_not_considered_yanked(self) -> None: + """Test that a version with some non-yanked files is not yanked.""" + releases = { + "1.0.0": [{"yanked": False}], + "2.0.0": [{"yanked": True}, {"yanked": False}], + } + assert _find_latest_non_yanked_version(releases) == "2.0.0" + + +class TestIsVersionYanked: + """Test _is_version_yanked helper.""" + + def test_non_yanked_version(self) -> None: + """Test a non-yanked version returns False.""" + releases = {"1.0.0": [{"yanked": False}]} + is_yanked, reason = _is_version_yanked("1.0.0", releases) + assert is_yanked is False + assert reason == "" + + def test_yanked_version_with_reason(self) -> None: + """Test a yanked version returns True with reason.""" + releases = { + "1.0.0": [{"yanked": True, "yanked_reason": "critical bug"}], + } + is_yanked, reason = _is_version_yanked("1.0.0", releases) + assert is_yanked is True + assert reason == "critical bug" + + def test_yanked_version_without_reason(self) -> None: + """Test a yanked version returns True with empty reason.""" + releases = {"1.0.0": [{"yanked": True}]} + is_yanked, reason = _is_version_yanked("1.0.0", releases) + assert is_yanked is True + assert reason == "" + + def test_unknown_version(self) -> None: + """Test an unknown version returns False.""" + releases = {"1.0.0": [{"yanked": False}]} + is_yanked, reason = _is_version_yanked("9.9.9", releases) + assert is_yanked is False + assert reason == "" + + def test_partially_yanked_files(self) -> None: + """Test a version with mixed yanked/non-yanked files is not yanked.""" + releases = { + "1.0.0": [{"yanked": True}, {"yanked": False}], + } + is_yanked, reason = _is_version_yanked("1.0.0", releases) + assert is_yanked is False + assert reason == "" + + def test_multiple_yanked_files_picks_first_reason(self) -> None: + """Test that the first available reason is returned.""" + releases = { + "1.0.0": [ + {"yanked": True, "yanked_reason": ""}, + {"yanked": True, "yanked_reason": "second reason"}, + ], + } + is_yanked, reason = _is_version_yanked("1.0.0", releases) + assert is_yanked is True + assert reason == "second reason" + + +class TestIsCurrentVersionYanked: + """Test is_current_version_yanked public function.""" + + @patch("crewai_cli.version.get_crewai_version") + @patch("crewai_cli.version._get_cache_file") + def test_reads_from_valid_cache( + self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path + ) -> None: + """Test reading yanked status from a valid cache.""" + mock_version.return_value = "1.0.0" + cache_file = tmp_path / "version_cache.json" + cache_data = { + "version": "2.0.0", + "timestamp": datetime.now().isoformat(), + "current_version": "1.0.0", + "current_version_yanked": True, + "current_version_yanked_reason": "bad release", + } + cache_file.write_text(json.dumps(cache_data)) + mock_cache_file.return_value = cache_file + + is_yanked, reason = is_current_version_yanked() + assert is_yanked is True + assert reason == "bad release" + + @patch("crewai_cli.version.get_crewai_version") + @patch("crewai_cli.version._get_cache_file") + def test_not_yanked_from_cache( + self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path + ) -> None: + """Test non-yanked status from a valid cache.""" + mock_version.return_value = "2.0.0" + cache_file = tmp_path / "version_cache.json" + cache_data = { + "version": "2.0.0", + "timestamp": datetime.now().isoformat(), + "current_version": "2.0.0", + "current_version_yanked": False, + "current_version_yanked_reason": "", + } + cache_file.write_text(json.dumps(cache_data)) + mock_cache_file.return_value = cache_file + + is_yanked, reason = is_current_version_yanked() + assert is_yanked is False + assert reason == "" + + @patch("crewai_cli.version.get_latest_version_from_pypi") + @patch("crewai_cli.version.get_crewai_version") + @patch("crewai_cli.version._get_cache_file") + def test_triggers_fetch_on_stale_cache( + self, + mock_cache_file: MagicMock, + mock_version: MagicMock, + mock_fetch: MagicMock, + tmp_path: Path, + ) -> None: + """Test that a stale cache triggers a re-fetch.""" + mock_version.return_value = "1.0.0" + cache_file = tmp_path / "version_cache.json" + old_time = datetime.now() - timedelta(hours=25) + cache_data = { + "version": "2.0.0", + "timestamp": old_time.isoformat(), + "current_version": "1.0.0", + "current_version_yanked": True, + "current_version_yanked_reason": "old reason", + } + cache_file.write_text(json.dumps(cache_data)) + mock_cache_file.return_value = cache_file + + fresh_cache = { + "version": "2.0.0", + "timestamp": datetime.now().isoformat(), + "current_version": "1.0.0", + "current_version_yanked": False, + "current_version_yanked_reason": "", + } + + def write_fresh_cache() -> str: + cache_file.write_text(json.dumps(fresh_cache)) + return "2.0.0" + + mock_fetch.side_effect = lambda: write_fresh_cache() + + is_yanked, reason = is_current_version_yanked() + assert is_yanked is False + mock_fetch.assert_called_once() + + @patch("crewai_cli.version.get_latest_version_from_pypi") + @patch("crewai_cli.version.get_crewai_version") + @patch("crewai_cli.version._get_cache_file") + def test_returns_false_on_fetch_failure( + self, + mock_cache_file: MagicMock, + mock_version: MagicMock, + mock_fetch: MagicMock, + tmp_path: Path, + ) -> None: + """Test that fetch failure returns not yanked.""" + mock_version.return_value = "1.0.0" + cache_file = tmp_path / "version_cache.json" + mock_cache_file.return_value = cache_file + mock_fetch.return_value = None + + is_yanked, reason = is_current_version_yanked() + assert is_yanked is False + assert reason == "" + + + +# TestConsoleFormatterVersionCheck tests remain in lib/crewai/tests/cli/test_version.py +# as they depend on crewai.events.utils.console_formatter (core package). diff --git a/lib/cli/tests/tools/__init__.py b/lib/cli/tests/tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/lib/cli/tests/tools/test_main.py b/lib/cli/tests/tools/test_main.py new file mode 100644 index 000000000..6b2606b38 --- /dev/null +++ b/lib/cli/tests/tools/test_main.py @@ -0,0 +1,390 @@ +import os +import tempfile +import unittest +import unittest.mock +from contextlib import contextmanager +from datetime import datetime, timedelta +from pathlib import Path +from unittest import mock +from unittest.mock import MagicMock, patch + +import pytest +from crewai_cli.shared.token_manager import TokenManager +from crewai_cli.tools.main import ToolCommand +from pytest import raises + + +@contextmanager +def in_temp_dir(): + original_dir = os.getcwd() + with tempfile.TemporaryDirectory() as temp_dir: + os.chdir(temp_dir) + try: + yield temp_dir + finally: + os.chdir(original_dir) + + +@pytest.fixture +def tool_command(): + # Create a temporary directory for each test to avoid token storage conflicts + with tempfile.TemporaryDirectory() as temp_dir: + # Mock the secure storage path to use the temp directory + with patch.object( + TokenManager, "_get_secure_storage_path", return_value=Path(temp_dir) + ): + TokenManager().save_tokens( + "test-token", (datetime.now() + timedelta(seconds=36000)).timestamp() + ) + tool_command = ToolCommand() + with patch.object(tool_command, "login"): + yield tool_command + + +@patch("crewai_cli.tools.main.subprocess.run") +def test_create_success(mock_subprocess, capsys, tool_command): + with in_temp_dir(): + tool_command.create("test-tool") + output = capsys.readouterr().out + assert "Creating custom tool test_tool..." in output + + assert os.path.isdir("test_tool") + assert os.path.isfile(os.path.join("test_tool", "README.md")) + assert os.path.isfile(os.path.join("test_tool", "pyproject.toml")) + assert os.path.isfile( + os.path.join("test_tool", "src", "test_tool", "__init__.py") + ) + assert os.path.isfile(os.path.join("test_tool", "src", "test_tool", "tool.py")) + + with open(os.path.join("test_tool", "src", "test_tool", "tool.py"), "r") as f: + content = f.read() + assert "class TestTool" in content + + mock_subprocess.assert_called_once_with(["git", "init"], check=True) + + +@patch("crewai_cli.tools.main.subprocess.run") +@patch("crewai_cli.plus_api.PlusAPI.get_tool") +@patch("crewai_cli.tools.main.ToolCommand._print_current_organization") +def test_install_success( + mock_print_org, mock_get, mock_subprocess_run, capsys, tool_command +): + mock_get_response = MagicMock() + mock_get_response.status_code = 200 + mock_get_response.json.return_value = { + "handle": "sample-tool", + "repository": {"handle": "sample-repo", "url": "https://example.com/repo"}, + } + mock_get.return_value = mock_get_response + mock_subprocess_run.return_value = MagicMock(stderr=None) + + tool_command.install("sample-tool") + output = capsys.readouterr().out + assert "Successfully installed sample-tool" in output + + mock_get.assert_has_calls([mock.call("sample-tool"), mock.call().json()]) + mock_subprocess_run.assert_any_call( + [ + "uv", + "add", + "--index", + "sample-repo=https://example.com/repo", + "sample-tool", + ], + capture_output=False, + text=True, + check=True, + env=unittest.mock.ANY, + ) + + # Verify _print_current_organization was called + mock_print_org.assert_called_once() + + +@patch("crewai_cli.tools.main.subprocess.run") +@patch("crewai_cli.plus_api.PlusAPI.get_tool") +def test_install_success_from_pypi(mock_get, mock_subprocess_run, capsys, tool_command): + mock_get_response = MagicMock() + mock_get_response.status_code = 200 + mock_get_response.json.return_value = { + "handle": "sample-tool", + "repository": {"handle": "sample-repo", "url": "https://example.com/repo"}, + "source": "pypi", + } + mock_get.return_value = mock_get_response + mock_subprocess_run.return_value = MagicMock(stderr=None) + + tool_command.install("sample-tool") + output = capsys.readouterr().out + assert "Successfully installed sample-tool" in output + + mock_get.assert_has_calls([mock.call("sample-tool"), mock.call().json()]) + mock_subprocess_run.assert_any_call( + [ + "uv", + "add", + "sample-tool", + ], + capture_output=False, + text=True, + check=True, + env=unittest.mock.ANY, + ) + + +@patch("crewai_cli.plus_api.PlusAPI.get_tool") +def test_install_tool_not_found(mock_get, capsys, tool_command): + mock_get_response = MagicMock() + mock_get_response.status_code = 404 + mock_get.return_value = mock_get_response + + with raises(SystemExit): + tool_command.install("non-existent-tool") + output = capsys.readouterr().out + assert "No tool found with this name" in output + + mock_get.assert_called_once_with("non-existent-tool") + + +@patch("crewai_cli.plus_api.PlusAPI.get_tool") +def test_install_api_error(mock_get, capsys, tool_command): + mock_get_response = MagicMock() + mock_get_response.status_code = 500 + mock_get.return_value = mock_get_response + + with raises(SystemExit): + tool_command.install("error-tool") + output = capsys.readouterr().out + assert "Failed to get tool details" in output + + mock_get.assert_called_once_with("error-tool") + + +@patch("crewai_cli.tools.main.git.Repository.fetch") +@patch("crewai_cli.tools.main.git.Repository.is_synced", return_value=False) +def test_publish_when_not_in_sync(mock_is_synced, mock_fetch, capsys, tool_command): + with raises(SystemExit): + tool_command.publish(is_public=True) + + output = capsys.readouterr().out + assert "Local changes need to be resolved before publishing" in output + + +@patch("crewai_cli.tools.main.get_project_name", return_value="sample-tool") +@patch("crewai_cli.tools.main.get_project_version", return_value="1.0.0") +@patch("crewai_cli.tools.main.get_project_description", return_value="A sample tool") +@patch("crewai_cli.tools.main.subprocess.run") +@patch("crewai_cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]) +@patch( + "crewai_cli.tools.main.open", + new_callable=unittest.mock.mock_open, + read_data=b"sample tarball content", +) +@patch("crewai_cli.tools.main.git.Repository.fetch") +@patch("crewai_cli.plus_api.PlusAPI.publish_tool") +@patch("crewai_cli.tools.main.git.Repository.is_synced", return_value=False) +@patch( + "crewai_cli.tools.main.extract_available_exports", + return_value=[{"name": "SampleTool"}], +) +@patch("crewai_cli.tools.main.ToolCommand._print_current_organization") +def test_publish_when_not_in_sync_and_force( + mock_print_org, + mock_available_exports, + mock_is_synced, + mock_publish, + mock_fetch, + mock_open, + mock_listdir, + mock_subprocess_run, + mock_get_project_description, + mock_get_project_version, + mock_get_project_name, + tool_command, +): + mock_publish_response = MagicMock() + mock_publish_response.status_code = 200 + mock_publish_response.json.return_value = {"handle": "sample-tool"} + mock_publish.return_value = mock_publish_response + + tool_command.publish(is_public=True, force=True) + + mock_get_project_name.assert_called_with(require=True) + mock_get_project_version.assert_called_with(require=True) + mock_get_project_description.assert_called_with(require=False) + mock_subprocess_run.assert_called_with( + ["uv", "build", "--sdist", "--out-dir", unittest.mock.ANY], + check=True, + capture_output=False, + ) + mock_open.assert_called_with(unittest.mock.ANY, "rb") + mock_publish.assert_called_with( + handle="sample-tool", + is_public=True, + version="1.0.0", + description="A sample tool", + encoded_file=unittest.mock.ANY, + available_exports=[{"name": "SampleTool"}], + ) + mock_print_org.assert_called_once() + + +@patch("crewai_cli.tools.main.get_project_name", return_value="sample-tool") +@patch("crewai_cli.tools.main.get_project_version", return_value="1.0.0") +@patch("crewai_cli.tools.main.get_project_description", return_value="A sample tool") +@patch("crewai_cli.tools.main.subprocess.run") +@patch("crewai_cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]) +@patch( + "crewai_cli.tools.main.open", + new_callable=unittest.mock.mock_open, + read_data=b"sample tarball content", +) +@patch("crewai_cli.tools.main.git.Repository.fetch") +@patch("crewai_cli.plus_api.PlusAPI.publish_tool") +@patch("crewai_cli.tools.main.git.Repository.is_synced", return_value=True) +@patch( + "crewai_cli.tools.main.extract_available_exports", + return_value=[{"name": "SampleTool"}], +) +def test_publish_success( + mock_available_exports, + mock_is_synced, + mock_publish, + mock_fetch, + mock_open, + mock_listdir, + mock_subprocess_run, + mock_get_project_description, + mock_get_project_version, + mock_get_project_name, + tool_command, +): + mock_publish_response = MagicMock() + mock_publish_response.status_code = 200 + mock_publish_response.json.return_value = {"handle": "sample-tool"} + mock_publish.return_value = mock_publish_response + + tool_command.publish(is_public=True) + + mock_get_project_name.assert_called_with(require=True) + mock_get_project_version.assert_called_with(require=True) + mock_get_project_description.assert_called_with(require=False) + mock_subprocess_run.assert_called_with( + ["uv", "build", "--sdist", "--out-dir", unittest.mock.ANY], + check=True, + capture_output=False, + ) + mock_open.assert_called_with(unittest.mock.ANY, "rb") + mock_publish.assert_called_with( + handle="sample-tool", + is_public=True, + version="1.0.0", + description="A sample tool", + encoded_file=unittest.mock.ANY, + available_exports=[{"name": "SampleTool"}], + ) + + +@patch("crewai_cli.tools.main.get_project_name", return_value="sample-tool") +@patch("crewai_cli.tools.main.get_project_version", return_value="1.0.0") +@patch("crewai_cli.tools.main.get_project_description", return_value="A sample tool") +@patch("crewai_cli.tools.main.subprocess.run") +@patch("crewai_cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]) +@patch( + "crewai_cli.tools.main.open", + new_callable=unittest.mock.mock_open, + read_data=b"sample tarball content", +) +@patch("crewai_cli.plus_api.PlusAPI.publish_tool") +@patch( + "crewai_cli.tools.main.extract_available_exports", + return_value=[{"name": "SampleTool"}], +) +def test_publish_failure( + mock_available_exports, + mock_publish, + mock_open, + mock_listdir, + mock_subprocess_run, + mock_get_project_description, + mock_get_project_version, + mock_get_project_name, + capsys, + tool_command, +): + mock_publish_response = MagicMock() + mock_publish_response.status_code = 422 + mock_publish_response.json.return_value = {"name": ["is already taken"]} + mock_publish.return_value = mock_publish_response + + with raises(SystemExit): + tool_command.publish(is_public=True) + output = capsys.readouterr().out + assert "Failed to complete operation" in output + assert "Name is already taken" in output + + mock_publish.assert_called_once() + + +@patch("crewai_cli.tools.main.get_project_name", return_value="sample-tool") +@patch("crewai_cli.tools.main.get_project_version", return_value="1.0.0") +@patch("crewai_cli.tools.main.get_project_description", return_value="A sample tool") +@patch("crewai_cli.tools.main.subprocess.run") +@patch("crewai_cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"]) +@patch( + "crewai_cli.tools.main.open", + new_callable=unittest.mock.mock_open, + read_data=b"sample tarball content", +) +@patch("crewai_cli.plus_api.PlusAPI.publish_tool") +@patch( + "crewai_cli.tools.main.extract_available_exports", + return_value=[{"name": "SampleTool"}], +) +def test_publish_api_error( + mock_available_exports, + mock_publish, + mock_open, + mock_listdir, + mock_subprocess_run, + mock_get_project_description, + mock_get_project_version, + mock_get_project_name, + capsys, + tool_command, +): + mock_response = MagicMock() + mock_response.status_code = 500 + mock_response.json.return_value = {"error": "Internal Server Error"} + mock_response.is_success = False + mock_publish.return_value = mock_response + + with raises(SystemExit): + tool_command.publish(is_public=True) + output = capsys.readouterr().out + assert "Request to Enterprise API failed" in output + + mock_publish.assert_called_once() + + +@patch("crewai_cli.tools.main.Settings") +def test_print_current_organization_with_org(mock_settings, capsys, tool_command): + mock_settings_instance = MagicMock() + mock_settings_instance.org_uuid = "test-org-uuid" + mock_settings_instance.org_name = "Test Organization" + mock_settings.return_value = mock_settings_instance + tool_command._print_current_organization() + output = capsys.readouterr().out + assert "Current organization: Test Organization (test-org-uuid)" in output + + +@patch("crewai_cli.tools.main.Settings") +def test_print_current_organization_without_org(mock_settings, capsys, tool_command): + mock_settings_instance = MagicMock() + mock_settings_instance.org_uuid = None + mock_settings_instance.org_name = None + mock_settings.return_value = mock_settings_instance + tool_command._print_current_organization() + output = capsys.readouterr().out + assert "No organization currently set" in output + assert "org switch " in output diff --git a/lib/cli/tests/triggers/test_main.py b/lib/cli/tests/triggers/test_main.py new file mode 100644 index 000000000..dc754c003 --- /dev/null +++ b/lib/cli/tests/triggers/test_main.py @@ -0,0 +1,170 @@ +import json +import subprocess +import unittest +from unittest.mock import Mock, patch + +import httpx +from crewai_cli.triggers.main import TriggersCommand + + +class TestTriggersCommand(unittest.TestCase): + @patch("crewai_cli.command.get_auth_token") + @patch("crewai_cli.command.PlusAPI") + def setUp(self, mock_plus_api, mock_get_auth_token): + self.mock_get_auth_token = mock_get_auth_token + self.mock_plus_api = mock_plus_api + + self.mock_get_auth_token.return_value = "test_token" + + self.triggers_command = TriggersCommand() + self.mock_client = self.triggers_command.plus_api_client + + @patch("crewai_cli.triggers.main.console.print") + def test_list_triggers_success(self, mock_console_print): + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.ok = True + mock_response.json.return_value = { + "apps": [ + { + "name": "Test App", + "slug": "test-app", + "description": "A test application", + "is_connected": True, + "triggers": [ + { + "name": "Test Trigger", + "slug": "test-trigger", + "description": "A test trigger" + } + ] + } + ] + } + self.mock_client.get_triggers.return_value = mock_response + + self.triggers_command.list_triggers() + + self.mock_client.get_triggers.assert_called_once() + mock_console_print.assert_any_call("[bold blue]Fetching available triggers...[/bold blue]") + + @patch("crewai_cli.triggers.main.console.print") + def test_list_triggers_no_apps(self, mock_console_print): + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.ok = True + mock_response.json.return_value = {"apps": []} + self.mock_client.get_triggers.return_value = mock_response + + self.triggers_command.list_triggers() + + mock_console_print.assert_any_call("[yellow]No triggers found.[/yellow]") + + @patch("crewai_cli.triggers.main.console.print") + def test_list_triggers_api_error(self, mock_console_print): + self.mock_client.get_triggers.side_effect = Exception("API Error") + + with self.assertRaises(SystemExit): + self.triggers_command.list_triggers() + + mock_console_print.assert_any_call("[bold red]Error fetching triggers: API Error[/bold red]") + + @patch("crewai_cli.triggers.main.console.print") + def test_execute_with_trigger_invalid_format(self, mock_console_print): + with self.assertRaises(SystemExit): + self.triggers_command.execute_with_trigger("invalid-format") + + mock_console_print.assert_called_with( + "[bold red]Error: Trigger must be in format 'app_slug/trigger_slug'[/bold red]" + ) + + @patch("crewai_cli.triggers.main.console.print") + @patch.object(TriggersCommand, "_run_crew_with_payload") + def test_execute_with_trigger_success(self, mock_run_crew, mock_console_print): + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 200 + mock_response.ok = True + mock_response.json.return_value = { + "sample_payload": {"key": "value", "data": "test"} + } + self.mock_client.get_trigger_payload.return_value = mock_response + + self.triggers_command.execute_with_trigger("test-app/test-trigger") + + self.mock_client.get_trigger_payload.assert_called_once_with("test-app", "test-trigger") + mock_run_crew.assert_called_once_with({"key": "value", "data": "test"}) + mock_console_print.assert_any_call( + "[bold blue]Fetching trigger payload for test-app/test-trigger...[/bold blue]" + ) + + @patch("crewai_cli.triggers.main.console.print") + def test_execute_with_trigger_not_found(self, mock_console_print): + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 404 + mock_response.json.return_value = {"error": "Trigger not found"} + self.mock_client.get_trigger_payload.return_value = mock_response + + with self.assertRaises(SystemExit): + self.triggers_command.execute_with_trigger("test-app/nonexistent-trigger") + + mock_console_print.assert_any_call("[bold red]Error: Trigger not found[/bold red]") + + @patch("crewai_cli.triggers.main.console.print") + def test_execute_with_trigger_api_error(self, mock_console_print): + self.mock_client.get_trigger_payload.side_effect = Exception("API Error") + + with self.assertRaises(SystemExit): + self.triggers_command.execute_with_trigger("test-app/test-trigger") + + mock_console_print.assert_any_call( + "[bold red]Error executing crew with trigger: API Error[/bold red]" + ) + + + @patch("subprocess.run") + def test_run_crew_with_payload_success(self, mock_subprocess): + payload = {"key": "value", "data": "test"} + mock_subprocess.return_value = None + + self.triggers_command._run_crew_with_payload(payload) + + mock_subprocess.assert_called_once_with( + ["uv", "run", "run_with_trigger", json.dumps(payload)], + capture_output=False, + text=True, + check=True + ) + + @patch("subprocess.run") + def test_run_crew_with_payload_failure(self, mock_subprocess): + payload = {"key": "value"} + mock_subprocess.side_effect = subprocess.CalledProcessError(1, "uv") + + with self.assertRaises(SystemExit): + self.triggers_command._run_crew_with_payload(payload) + + @patch("subprocess.run") + def test_run_crew_with_payload_empty_payload(self, mock_subprocess): + payload = {} + mock_subprocess.return_value = None + + self.triggers_command._run_crew_with_payload(payload) + + mock_subprocess.assert_called_once_with( + ["uv", "run", "run_with_trigger", "{}"], + capture_output=False, + text=True, + check=True + ) + + @patch("crewai_cli.triggers.main.console.print") + def test_execute_with_trigger_with_default_error_message(self, mock_console_print): + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 404 + mock_response.json.return_value = {} + self.mock_client.get_trigger_payload.return_value = mock_response + + with self.assertRaises(SystemExit): + self.triggers_command.execute_with_trigger("test-app/test-trigger") + + mock_console_print.assert_any_call("[bold red]Error: Trigger not found[/bold red]")