mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
test: add CLI tests to crewai-cli package
Move and adapt all CLI tests from lib/crewai/tests/cli/ to lib/cli/tests/, updating import paths from crewai.cli.* to crewai_cli.* and adjusting mock targets accordingly.
This commit is contained in:
0
lib/cli/tests/__init__.py
Normal file
0
lib/cli/tests/__init__.py
Normal file
0
lib/cli/tests/authentication/__init__.py
Normal file
0
lib/cli/tests/authentication/__init__.py
Normal file
0
lib/cli/tests/authentication/providers/__init__.py
Normal file
0
lib/cli/tests/authentication/providers/__init__.py
Normal file
91
lib/cli/tests/authentication/providers/test_auth0.py
Normal file
91
lib/cli/tests/authentication/providers/test_auth0.py
Normal file
@@ -0,0 +1,91 @@
|
||||
import pytest
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.auth0 import Auth0Provider
|
||||
|
||||
|
||||
|
||||
class TestAuth0Provider:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.valid_settings = Oauth2Settings(
|
||||
provider="auth0",
|
||||
domain="test-domain.auth0.com",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience"
|
||||
)
|
||||
self.provider = Auth0Provider(self.valid_settings)
|
||||
|
||||
def test_initialization_with_valid_settings(self):
|
||||
provider = Auth0Provider(self.valid_settings)
|
||||
assert provider.settings == self.valid_settings
|
||||
assert provider.settings.provider == "auth0"
|
||||
assert provider.settings.domain == "test-domain.auth0.com"
|
||||
assert provider.settings.client_id == "test-client-id"
|
||||
assert provider.settings.audience == "test-audience"
|
||||
|
||||
def test_get_authorize_url(self):
|
||||
expected_url = "https://test-domain.auth0.com/oauth/device/code"
|
||||
assert self.provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="auth0",
|
||||
domain="my-company.auth0.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience"
|
||||
)
|
||||
provider = Auth0Provider(settings)
|
||||
expected_url = "https://my-company.auth0.com/oauth/device/code"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_token_url(self):
|
||||
expected_url = "https://test-domain.auth0.com/oauth/token"
|
||||
assert self.provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="auth0",
|
||||
domain="another-domain.auth0.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience"
|
||||
)
|
||||
provider = Auth0Provider(settings)
|
||||
expected_url = "https://another-domain.auth0.com/oauth/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_jwks_url(self):
|
||||
expected_url = "https://test-domain.auth0.com/.well-known/jwks.json"
|
||||
assert self.provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="auth0",
|
||||
domain="dev.auth0.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience"
|
||||
)
|
||||
provider = Auth0Provider(settings)
|
||||
expected_url = "https://dev.auth0.com/.well-known/jwks.json"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_issuer(self):
|
||||
expected_issuer = "https://test-domain.auth0.com/"
|
||||
assert self.provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="auth0",
|
||||
domain="prod.auth0.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience"
|
||||
)
|
||||
provider = Auth0Provider(settings)
|
||||
expected_issuer = "https://prod.auth0.com/"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_audience(self):
|
||||
assert self.provider.get_audience() == "test-audience"
|
||||
|
||||
def test_get_client_id(self):
|
||||
assert self.provider.get_client_id() == "test-client-id"
|
||||
141
lib/cli/tests/authentication/providers/test_entra_id.py
Normal file
141
lib/cli/tests/authentication/providers/test_entra_id.py
Normal file
@@ -0,0 +1,141 @@
|
||||
import pytest
|
||||
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.entra_id import EntraIdProvider
|
||||
|
||||
|
||||
class TestEntraIdProvider:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.valid_settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="tenant-id-abcdef123456",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"scope": "openid profile email api://crewai-cli-dev/read"
|
||||
}
|
||||
)
|
||||
self.provider = EntraIdProvider(self.valid_settings)
|
||||
|
||||
def test_initialization_with_valid_settings(self):
|
||||
provider = EntraIdProvider(self.valid_settings)
|
||||
assert provider.settings == self.valid_settings
|
||||
assert provider.settings.provider == "entra_id"
|
||||
assert provider.settings.domain == "tenant-id-abcdef123456"
|
||||
assert provider.settings.client_id == "test-client-id"
|
||||
assert provider.settings.audience == "test-audience"
|
||||
|
||||
def test_get_authorize_url(self):
|
||||
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/devicecode"
|
||||
assert self.provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_with_different_domain(self):
|
||||
# For EntraID, the domain is the tenant ID.
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="my-company.entra.id",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
expected_url = "https://login.microsoftonline.com/my-company.entra.id/oauth2/v2.0/devicecode"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_token_url(self):
|
||||
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/oauth2/v2.0/token"
|
||||
assert self.provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_with_different_domain(self):
|
||||
# For EntraID, the domain is the tenant ID.
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="another-domain.entra.id",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
expected_url = "https://login.microsoftonline.com/another-domain.entra.id/oauth2/v2.0/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_jwks_url(self):
|
||||
expected_url = "https://login.microsoftonline.com/tenant-id-abcdef123456/discovery/v2.0/keys"
|
||||
assert self.provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_with_different_domain(self):
|
||||
# For EntraID, the domain is the tenant ID.
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="dev.entra.id",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
expected_url = "https://login.microsoftonline.com/dev.entra.id/discovery/v2.0/keys"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_issuer(self):
|
||||
expected_issuer = "https://login.microsoftonline.com/tenant-id-abcdef123456/v2.0"
|
||||
assert self.provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_with_different_domain(self):
|
||||
# For EntraID, the domain is the tenant ID.
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="other-tenant-id-xpto",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
expected_issuer = "https://login.microsoftonline.com/other-tenant-id-xpto/v2.0"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_audience(self):
|
||||
assert self.provider.get_audience() == "test-audience"
|
||||
|
||||
def test_get_audience_assertion_error_when_none(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="test-tenant-id",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
|
||||
with pytest.raises(ValueError, match="Audience is required"):
|
||||
provider.get_audience()
|
||||
|
||||
def test_get_client_id(self):
|
||||
assert self.provider.get_client_id() == "test-client-id"
|
||||
|
||||
def test_get_required_fields(self):
|
||||
assert set(self.provider.get_required_fields()) == set(["scope"])
|
||||
|
||||
def test_get_oauth_scopes(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="tenant-id-abcdef123456",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"scope": "api://crewai-cli-dev/read"
|
||||
}
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read"]
|
||||
|
||||
def test_get_oauth_scopes_with_multiple_custom_scopes(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="entra_id",
|
||||
domain="tenant-id-abcdef123456",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"scope": "api://crewai-cli-dev/read api://crewai-cli-dev/write custom-scope1 custom-scope2"
|
||||
}
|
||||
)
|
||||
provider = EntraIdProvider(settings)
|
||||
assert provider.get_oauth_scopes() == ["openid", "profile", "email", "api://crewai-cli-dev/read", "api://crewai-cli-dev/write", "custom-scope1", "custom-scope2"]
|
||||
|
||||
def test_base_url(self):
|
||||
assert self.provider._base_url() == "https://login.microsoftonline.com/tenant-id-abcdef123456"
|
||||
138
lib/cli/tests/authentication/providers/test_keycloak.py
Normal file
138
lib/cli/tests/authentication/providers/test_keycloak.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import pytest
|
||||
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.keycloak import KeycloakProvider
|
||||
|
||||
|
||||
class TestKeycloakProvider:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.valid_settings = Oauth2Settings(
|
||||
provider="keycloak",
|
||||
domain="keycloak.example.com",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"realm": "test-realm"
|
||||
}
|
||||
)
|
||||
self.provider = KeycloakProvider(self.valid_settings)
|
||||
|
||||
def test_initialization_with_valid_settings(self):
|
||||
provider = KeycloakProvider(self.valid_settings)
|
||||
assert provider.settings == self.valid_settings
|
||||
assert provider.settings.provider == "keycloak"
|
||||
assert provider.settings.domain == "keycloak.example.com"
|
||||
assert provider.settings.client_id == "test-client-id"
|
||||
assert provider.settings.audience == "test-audience"
|
||||
assert provider.settings.extra.get("realm") == "test-realm"
|
||||
|
||||
def test_get_authorize_url(self):
|
||||
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/auth/device"
|
||||
assert self.provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="keycloak",
|
||||
domain="auth.company.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"realm": "my-realm"
|
||||
}
|
||||
)
|
||||
provider = KeycloakProvider(settings)
|
||||
expected_url = "https://auth.company.com/realms/my-realm/protocol/openid-connect/auth/device"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_token_url(self):
|
||||
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/token"
|
||||
assert self.provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="keycloak",
|
||||
domain="sso.enterprise.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"realm": "enterprise-realm"
|
||||
}
|
||||
)
|
||||
provider = KeycloakProvider(settings)
|
||||
expected_url = "https://sso.enterprise.com/realms/enterprise-realm/protocol/openid-connect/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_jwks_url(self):
|
||||
expected_url = "https://keycloak.example.com/realms/test-realm/protocol/openid-connect/certs"
|
||||
assert self.provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="keycloak",
|
||||
domain="identity.org",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"realm": "org-realm"
|
||||
}
|
||||
)
|
||||
provider = KeycloakProvider(settings)
|
||||
expected_url = "https://identity.org/realms/org-realm/protocol/openid-connect/certs"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_issuer(self):
|
||||
expected_issuer = "https://keycloak.example.com/realms/test-realm"
|
||||
assert self.provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="keycloak",
|
||||
domain="login.myapp.io",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"realm": "app-realm"
|
||||
}
|
||||
)
|
||||
provider = KeycloakProvider(settings)
|
||||
expected_issuer = "https://login.myapp.io/realms/app-realm"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_audience(self):
|
||||
assert self.provider.get_audience() == "test-audience"
|
||||
|
||||
def test_get_client_id(self):
|
||||
assert self.provider.get_client_id() == "test-client-id"
|
||||
|
||||
def test_get_required_fields(self):
|
||||
assert self.provider.get_required_fields() == ["realm"]
|
||||
|
||||
def test_oauth2_base_url(self):
|
||||
assert self.provider._oauth2_base_url() == "https://keycloak.example.com"
|
||||
|
||||
def test_oauth2_base_url_strips_https_prefix(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="keycloak",
|
||||
domain="https://keycloak.example.com",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"realm": "test-realm"
|
||||
}
|
||||
)
|
||||
provider = KeycloakProvider(settings)
|
||||
assert provider._oauth2_base_url() == "https://keycloak.example.com"
|
||||
|
||||
def test_oauth2_base_url_strips_http_prefix(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="keycloak",
|
||||
domain="http://keycloak.example.com",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
extra={
|
||||
"realm": "test-realm"
|
||||
}
|
||||
)
|
||||
provider = KeycloakProvider(settings)
|
||||
assert provider._oauth2_base_url() == "https://keycloak.example.com"
|
||||
257
lib/cli/tests/authentication/providers/test_okta.py
Normal file
257
lib/cli/tests/authentication/providers/test_okta.py
Normal file
@@ -0,0 +1,257 @@
|
||||
import pytest
|
||||
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.okta import OktaProvider
|
||||
|
||||
|
||||
class TestOktaProvider:
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.valid_settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience",
|
||||
)
|
||||
self.provider = OktaProvider(self.valid_settings)
|
||||
|
||||
def test_initialization_with_valid_settings(self):
|
||||
provider = OktaProvider(self.valid_settings)
|
||||
assert provider.settings == self.valid_settings
|
||||
assert provider.settings.provider == "okta"
|
||||
assert provider.settings.domain == "test-domain.okta.com"
|
||||
assert provider.settings.client_id == "test-client-id"
|
||||
assert provider.settings.audience == "test-audience"
|
||||
|
||||
def test_get_authorize_url(self):
|
||||
expected_url = "https://test-domain.okta.com/oauth2/default/v1/device/authorize"
|
||||
assert self.provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="my-company.okta.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://my-company.okta.com/oauth2/default/v1/device/authorize"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_with_custom_authorization_server_name(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": False,
|
||||
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/device/authorize"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_when_using_org_auth_server(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": True,
|
||||
"authorization_server_name": None
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/v1/device/authorize"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_token_url(self):
|
||||
expected_url = "https://test-domain.okta.com/oauth2/default/v1/token"
|
||||
assert self.provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="another-domain.okta.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://another-domain.okta.com/oauth2/default/v1/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_with_custom_authorization_server_name(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": False,
|
||||
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_when_using_org_auth_server(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": True,
|
||||
"authorization_server_name": None
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/v1/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_jwks_url(self):
|
||||
expected_url = "https://test-domain.okta.com/oauth2/default/v1/keys"
|
||||
assert self.provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="dev.okta.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://dev.okta.com/oauth2/default/v1/keys"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_with_custom_authorization_server_name(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": False,
|
||||
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777/v1/keys"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_when_using_org_auth_server(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": True,
|
||||
"authorization_server_name": None
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_url = "https://test-domain.okta.com/oauth2/v1/keys"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_issuer(self):
|
||||
expected_issuer = "https://test-domain.okta.com/oauth2/default"
|
||||
assert self.provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="prod.okta.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience",
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_issuer = "https://prod.okta.com/oauth2/default"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_with_custom_authorization_server_name(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": False,
|
||||
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_issuer = "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_when_using_org_auth_server(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": True,
|
||||
"authorization_server_name": None
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
expected_issuer = "https://test-domain.okta.com"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_audience(self):
|
||||
assert self.provider.get_audience() == "test-audience"
|
||||
|
||||
def test_get_audience_assertion_error_when_none(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
|
||||
with pytest.raises(ValueError, match="Audience is required"):
|
||||
provider.get_audience()
|
||||
|
||||
def test_get_client_id(self):
|
||||
assert self.provider.get_client_id() == "test-client-id"
|
||||
|
||||
def test_get_required_fields(self):
|
||||
assert set(self.provider.get_required_fields()) == set(["authorization_server_name", "using_org_auth_server"])
|
||||
|
||||
def test_oauth2_base_url(self):
|
||||
assert self.provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/default"
|
||||
|
||||
def test_oauth2_base_url_with_custom_authorization_server_name(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": False,
|
||||
"authorization_server_name": "my_auth_server_xxxAAA777"
|
||||
}
|
||||
)
|
||||
|
||||
provider = OktaProvider(settings)
|
||||
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2/my_auth_server_xxxAAA777"
|
||||
|
||||
def test_oauth2_base_url_when_using_org_auth_server(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="okta",
|
||||
domain="test-domain.okta.com",
|
||||
client_id="test-client-id",
|
||||
audience=None,
|
||||
extra={
|
||||
"using_org_auth_server": True,
|
||||
"authorization_server_name": None
|
||||
}
|
||||
)
|
||||
provider = OktaProvider(settings)
|
||||
assert provider._oauth2_base_url() == "https://test-domain.okta.com/oauth2"
|
||||
100
lib/cli/tests/authentication/providers/test_workos.py
Normal file
100
lib/cli/tests/authentication/providers/test_workos.py
Normal file
@@ -0,0 +1,100 @@
|
||||
import pytest
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.workos import WorkosProvider
|
||||
|
||||
|
||||
class TestWorkosProvider:
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def setup_method(self):
|
||||
self.valid_settings = Oauth2Settings(
|
||||
provider="workos",
|
||||
domain="login.company.com",
|
||||
client_id="test-client-id",
|
||||
audience="test-audience"
|
||||
)
|
||||
self.provider = WorkosProvider(self.valid_settings)
|
||||
|
||||
def test_initialization_with_valid_settings(self):
|
||||
provider = WorkosProvider(self.valid_settings)
|
||||
assert provider.settings == self.valid_settings
|
||||
assert provider.settings.provider == "workos"
|
||||
assert provider.settings.domain == "login.company.com"
|
||||
assert provider.settings.client_id == "test-client-id"
|
||||
assert provider.settings.audience == "test-audience"
|
||||
|
||||
def test_get_authorize_url(self):
|
||||
expected_url = "https://login.company.com/oauth2/device_authorization"
|
||||
assert self.provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_authorize_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="workos",
|
||||
domain="login.example.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience"
|
||||
)
|
||||
provider = WorkosProvider(settings)
|
||||
expected_url = "https://login.example.com/oauth2/device_authorization"
|
||||
assert provider.get_authorize_url() == expected_url
|
||||
|
||||
def test_get_token_url(self):
|
||||
expected_url = "https://login.company.com/oauth2/token"
|
||||
assert self.provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_token_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="workos",
|
||||
domain="api.workos.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience"
|
||||
)
|
||||
provider = WorkosProvider(settings)
|
||||
expected_url = "https://api.workos.com/oauth2/token"
|
||||
assert provider.get_token_url() == expected_url
|
||||
|
||||
def test_get_jwks_url(self):
|
||||
expected_url = "https://login.company.com/oauth2/jwks"
|
||||
assert self.provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_jwks_url_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="workos",
|
||||
domain="auth.enterprise.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience"
|
||||
)
|
||||
provider = WorkosProvider(settings)
|
||||
expected_url = "https://auth.enterprise.com/oauth2/jwks"
|
||||
assert provider.get_jwks_url() == expected_url
|
||||
|
||||
def test_get_issuer(self):
|
||||
expected_issuer = "https://login.company.com"
|
||||
assert self.provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_issuer_with_different_domain(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="workos",
|
||||
domain="sso.company.com",
|
||||
client_id="test-client",
|
||||
audience="test-audience"
|
||||
)
|
||||
provider = WorkosProvider(settings)
|
||||
expected_issuer = "https://sso.company.com"
|
||||
assert provider.get_issuer() == expected_issuer
|
||||
|
||||
def test_get_audience(self):
|
||||
assert self.provider.get_audience() == "test-audience"
|
||||
|
||||
def test_get_audience_fallback_to_default(self):
|
||||
settings = Oauth2Settings(
|
||||
provider="workos",
|
||||
domain="login.company.com",
|
||||
client_id="test-client-id",
|
||||
audience=None
|
||||
)
|
||||
provider = WorkosProvider(settings)
|
||||
assert provider.get_audience() == ""
|
||||
|
||||
def test_get_client_id(self):
|
||||
assert self.provider.get_client_id() == "test-client-id"
|
||||
348
lib/cli/tests/authentication/test_auth_main.py
Normal file
348
lib/cli/tests/authentication/test_auth_main.py
Normal file
@@ -0,0 +1,348 @@
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock, call, patch
|
||||
|
||||
import pytest
|
||||
import httpx
|
||||
from crewai_cli.authentication.main import AuthenticationCommand
|
||||
from crewai_cli.constants import (
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
)
|
||||
|
||||
|
||||
class TestAuthenticationCommand:
|
||||
def setup_method(self):
|
||||
# Mock Settings so we always use default constants regardless of local config.
|
||||
with patch("crewai_cli.authentication.main.Settings") as mock_settings:
|
||||
instance = mock_settings.return_value
|
||||
instance.oauth2_provider = "workos"
|
||||
instance.oauth2_domain = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN
|
||||
instance.oauth2_client_id = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID
|
||||
instance.oauth2_audience = CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE
|
||||
instance.oauth2_extra = {}
|
||||
self.auth_command = AuthenticationCommand()
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"user_provider,expected_urls",
|
||||
[
|
||||
(
|
||||
"workos",
|
||||
{
|
||||
"device_code_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/device_authorization",
|
||||
"token_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/token",
|
||||
"client_id": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_CLIENT_ID,
|
||||
"audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
"domain": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
@patch("crewai_cli.authentication.main.AuthenticationCommand._get_device_code")
|
||||
@patch(
|
||||
"crewai_cli.authentication.main.AuthenticationCommand._display_auth_instructions"
|
||||
)
|
||||
@patch("crewai_cli.authentication.main.AuthenticationCommand._poll_for_token")
|
||||
@patch("crewai_cli.authentication.main.console.print")
|
||||
def test_login(
|
||||
self,
|
||||
mock_console_print,
|
||||
mock_poll,
|
||||
mock_display,
|
||||
mock_get_device,
|
||||
user_provider,
|
||||
expected_urls,
|
||||
):
|
||||
mock_get_device.return_value = {
|
||||
"device_code": "test_code",
|
||||
"user_code": "123456",
|
||||
}
|
||||
|
||||
self.auth_command.login()
|
||||
|
||||
mock_console_print.assert_called_once_with(
|
||||
"Signing in to CrewAI AMP...\n", style="bold blue"
|
||||
)
|
||||
mock_get_device.assert_called_once()
|
||||
mock_display.assert_called_once_with(
|
||||
{"device_code": "test_code", "user_code": "123456"}
|
||||
)
|
||||
mock_poll.assert_called_once_with(
|
||||
{"device_code": "test_code", "user_code": "123456"},
|
||||
)
|
||||
assert (
|
||||
self.auth_command.oauth2_provider.get_client_id()
|
||||
== expected_urls["client_id"]
|
||||
)
|
||||
assert (
|
||||
self.auth_command.oauth2_provider.get_audience()
|
||||
== expected_urls["audience"]
|
||||
)
|
||||
assert (
|
||||
self.auth_command.oauth2_provider._get_domain() == expected_urls["domain"]
|
||||
)
|
||||
|
||||
@patch("crewai_cli.authentication.main.webbrowser")
|
||||
@patch("crewai_cli.authentication.main.console.print")
|
||||
def test_display_auth_instructions(self, mock_console_print, mock_webbrowser):
|
||||
device_code_data = {
|
||||
"verification_uri_complete": "https://example.com/auth",
|
||||
"user_code": "123456",
|
||||
}
|
||||
|
||||
self.auth_command._display_auth_instructions(device_code_data)
|
||||
|
||||
expected_calls = [
|
||||
call("1. Navigate to: ", "https://example.com/auth"),
|
||||
call("2. Enter the following code: ", "123456"),
|
||||
]
|
||||
mock_console_print.assert_has_calls(expected_calls)
|
||||
mock_webbrowser.open.assert_called_once_with("https://example.com/auth")
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"user_provider,jwt_config",
|
||||
[
|
||||
(
|
||||
"workos",
|
||||
{
|
||||
"jwks_url": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}/oauth2/jwks",
|
||||
"issuer": f"https://{CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN}",
|
||||
"audience": CREWAI_ENTERPRISE_DEFAULT_OAUTH2_AUDIENCE,
|
||||
},
|
||||
),
|
||||
],
|
||||
)
|
||||
@pytest.mark.parametrize("has_expiration", [True, False])
|
||||
@patch("crewai_cli.authentication.main.validate_jwt_token")
|
||||
@patch("crewai_cli.authentication.main.TokenManager.save_tokens")
|
||||
def test_validate_and_save_token(
|
||||
self,
|
||||
mock_save_tokens,
|
||||
mock_validate_jwt,
|
||||
user_provider,
|
||||
jwt_config,
|
||||
has_expiration,
|
||||
):
|
||||
from crewai_cli.authentication.main import Oauth2Settings
|
||||
from crewai_cli.authentication.providers.workos import WorkosProvider
|
||||
|
||||
if user_provider == "workos":
|
||||
self.auth_command.oauth2_provider = WorkosProvider(
|
||||
settings=Oauth2Settings(
|
||||
provider=user_provider,
|
||||
client_id="test-client-id",
|
||||
domain=CREWAI_ENTERPRISE_DEFAULT_OAUTH2_DOMAIN,
|
||||
audience=jwt_config["audience"],
|
||||
)
|
||||
)
|
||||
|
||||
token_data = {"access_token": "test_access_token", "id_token": "test_id_token"}
|
||||
|
||||
if has_expiration:
|
||||
future_timestamp = int((datetime.now() + timedelta(days=100)).timestamp())
|
||||
decoded_token = {"exp": future_timestamp}
|
||||
else:
|
||||
decoded_token = {}
|
||||
|
||||
mock_validate_jwt.return_value = decoded_token
|
||||
|
||||
self.auth_command._validate_and_save_token(token_data)
|
||||
|
||||
mock_validate_jwt.assert_called_once_with(
|
||||
jwt_token="test_access_token",
|
||||
jwks_url=jwt_config["jwks_url"],
|
||||
issuer=jwt_config["issuer"],
|
||||
audience=jwt_config["audience"],
|
||||
)
|
||||
|
||||
if has_expiration:
|
||||
mock_save_tokens.assert_called_once_with(
|
||||
"test_access_token", future_timestamp
|
||||
)
|
||||
else:
|
||||
mock_save_tokens.assert_called_once_with("test_access_token", 0)
|
||||
|
||||
@patch("crewai_cli.tools.main.ToolCommand")
|
||||
@patch("crewai_cli.authentication.main.Settings")
|
||||
@patch("crewai_cli.authentication.main.console.print")
|
||||
def test_login_to_tool_repository_success(
|
||||
self, mock_console_print, mock_settings, mock_tool_command
|
||||
):
|
||||
mock_tool_instance = MagicMock()
|
||||
mock_tool_command.return_value = mock_tool_instance
|
||||
|
||||
mock_settings_instance = MagicMock()
|
||||
mock_settings_instance.org_name = "Test Org"
|
||||
mock_settings_instance.org_uuid = "test-uuid-123"
|
||||
mock_settings.return_value = mock_settings_instance
|
||||
|
||||
self.auth_command._login_to_tool_repository()
|
||||
|
||||
mock_tool_command.assert_called_once()
|
||||
mock_tool_instance.login.assert_called_once()
|
||||
|
||||
expected_calls = [
|
||||
call(
|
||||
"Now logging you in to the Tool Repository... ",
|
||||
style="bold blue",
|
||||
end="",
|
||||
),
|
||||
call("Success!\n", style="bold green"),
|
||||
call(
|
||||
"You are now authenticated to the tool repository for organization [bold cyan]'Test Org'[/bold cyan]",
|
||||
style="green",
|
||||
),
|
||||
]
|
||||
mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
@patch("crewai_cli.tools.main.ToolCommand")
|
||||
@patch("crewai_cli.authentication.main.console.print")
|
||||
def test_login_to_tool_repository_error(
|
||||
self, mock_console_print, mock_tool_command
|
||||
):
|
||||
mock_tool_instance = MagicMock()
|
||||
mock_tool_instance.login.side_effect = Exception("Tool repository error")
|
||||
mock_tool_command.return_value = mock_tool_instance
|
||||
|
||||
self.auth_command._login_to_tool_repository()
|
||||
|
||||
mock_tool_command.assert_called_once()
|
||||
mock_tool_instance.login.assert_called_once()
|
||||
|
||||
expected_calls = [
|
||||
call(
|
||||
"Now logging you in to the Tool Repository... ",
|
||||
style="bold blue",
|
||||
end="",
|
||||
),
|
||||
call(
|
||||
"\n[bold yellow]Warning:[/bold yellow] Authentication with the Tool Repository failed.",
|
||||
style="yellow",
|
||||
),
|
||||
call(
|
||||
"Other features will work normally, but you may experience limitations with downloading and publishing tools.\nRun [bold]crewai login[/bold] to try logging in again.\n",
|
||||
style="yellow",
|
||||
),
|
||||
]
|
||||
mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
@patch("crewai_cli.authentication.main.httpx.post")
|
||||
def test_get_device_code(self, mock_post):
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = {
|
||||
"device_code": "test_device_code",
|
||||
"user_code": "123456",
|
||||
"verification_uri_complete": "https://example.com/auth",
|
||||
}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
self.auth_command.oauth2_provider = MagicMock()
|
||||
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
|
||||
self.auth_command.oauth2_provider.get_authorize_url.return_value = (
|
||||
"https://example.com/device"
|
||||
)
|
||||
self.auth_command.oauth2_provider.get_audience.return_value = "test_audience"
|
||||
self.auth_command.oauth2_provider.get_oauth_scopes.return_value = ["openid", "profile", "email"]
|
||||
|
||||
result = self.auth_command._get_device_code()
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
url="https://example.com/device",
|
||||
data={
|
||||
"client_id": "test_client",
|
||||
"scope": "openid profile email",
|
||||
"audience": "test_audience",
|
||||
},
|
||||
timeout=20,
|
||||
)
|
||||
|
||||
assert result == {
|
||||
"device_code": "test_device_code",
|
||||
"user_code": "123456",
|
||||
"verification_uri_complete": "https://example.com/auth",
|
||||
}
|
||||
|
||||
@patch("crewai_cli.authentication.main.httpx.post")
|
||||
@patch("crewai_cli.authentication.main.console.print")
|
||||
def test_poll_for_token_success(self, mock_console_print, mock_post):
|
||||
mock_response_success = MagicMock()
|
||||
mock_response_success.status_code = 200
|
||||
mock_response_success.json.return_value = {
|
||||
"access_token": "test_access_token",
|
||||
"id_token": "test_id_token",
|
||||
}
|
||||
mock_post.return_value = mock_response_success
|
||||
|
||||
device_code_data = {"device_code": "test_device_code", "interval": 1}
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
self.auth_command, "_validate_and_save_token"
|
||||
) as mock_validate,
|
||||
patch.object(
|
||||
self.auth_command, "_login_to_tool_repository"
|
||||
) as mock_tool_login,
|
||||
):
|
||||
self.auth_command.oauth2_provider = MagicMock()
|
||||
self.auth_command.oauth2_provider.get_token_url.return_value = (
|
||||
"https://example.com/token"
|
||||
)
|
||||
self.auth_command.oauth2_provider.get_client_id.return_value = "test_client"
|
||||
|
||||
self.auth_command._poll_for_token(device_code_data)
|
||||
|
||||
mock_post.assert_called_once_with(
|
||||
"https://example.com/token",
|
||||
data={
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code",
|
||||
"device_code": "test_device_code",
|
||||
"client_id": "test_client",
|
||||
},
|
||||
timeout=30,
|
||||
)
|
||||
|
||||
mock_validate.assert_called_once()
|
||||
mock_tool_login.assert_called_once()
|
||||
|
||||
expected_calls = [
|
||||
call("\nWaiting for authentication... ", style="bold blue", end=""),
|
||||
call("Success!", style="bold green"),
|
||||
call("\n[bold green]Welcome to CrewAI AMP![/bold green]\n"),
|
||||
]
|
||||
mock_console_print.assert_has_calls(expected_calls)
|
||||
|
||||
@patch("crewai_cli.authentication.main.httpx.post")
|
||||
@patch("crewai_cli.authentication.main.console.print")
|
||||
def test_poll_for_token_timeout(self, mock_console_print, mock_post):
|
||||
mock_response_pending = MagicMock()
|
||||
mock_response_pending.status_code = 400
|
||||
mock_response_pending.json.return_value = {"error": "authorization_pending"}
|
||||
mock_post.return_value = mock_response_pending
|
||||
|
||||
device_code_data = {
|
||||
"device_code": "test_device_code",
|
||||
"interval": 0.1, # Short interval for testing
|
||||
}
|
||||
|
||||
self.auth_command._poll_for_token(device_code_data)
|
||||
|
||||
mock_console_print.assert_any_call(
|
||||
"Timeout: Failed to get the token. Please try again.", style="bold red"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.authentication.main.httpx.post")
|
||||
def test_poll_for_token_error(self, mock_post):
|
||||
"""Test the method to poll for token (error path)."""
|
||||
# Setup mock to return error
|
||||
mock_response_error = MagicMock()
|
||||
mock_response_error.status_code = 400
|
||||
mock_response_error.json.return_value = {
|
||||
"error": "access_denied",
|
||||
"error_description": "User denied access",
|
||||
}
|
||||
mock_post.return_value = mock_response_error
|
||||
|
||||
device_code_data = {"device_code": "test_device_code", "interval": 1}
|
||||
|
||||
with pytest.raises(httpx.HTTPError):
|
||||
self.auth_command._poll_for_token(device_code_data)
|
||||
107
lib/cli/tests/authentication/test_utils.py
Normal file
107
lib/cli/tests/authentication/test_utils.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import jwt
|
||||
|
||||
from crewai_cli.authentication.utils import validate_jwt_token
|
||||
|
||||
|
||||
@patch("crewai_cli.authentication.utils.PyJWKClient", return_value=MagicMock())
|
||||
@patch("crewai_cli.authentication.utils.jwt")
|
||||
class TestUtils(unittest.TestCase):
|
||||
def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.return_value = {"exp": 1719859200}
|
||||
|
||||
# Create signing key object mock with a .key attribute
|
||||
mock_pyjwkclient.return_value.get_signing_key_from_jwt.return_value = MagicMock(
|
||||
key="mock_signing_key"
|
||||
)
|
||||
|
||||
jwt_token = "aaaaa.bbbbbb.cccccc" # noqa: S105
|
||||
|
||||
decoded_token = validate_jwt_token(
|
||||
jwt_token=jwt_token,
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
|
||||
mock_jwt.decode.assert_called_with(
|
||||
jwt_token,
|
||||
"mock_signing_key",
|
||||
algorithms=["RS256"],
|
||||
audience="app_id_xxxx",
|
||||
issuer="https://mock_issuer",
|
||||
leeway=10.0,
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
"verify_nbf": True,
|
||||
"verify_iat": True,
|
||||
"require": ["exp", "iat", "iss", "aud", "sub"],
|
||||
},
|
||||
)
|
||||
mock_pyjwkclient.assert_called_once_with("https://mock_jwks_url")
|
||||
self.assertEqual(decoded_token, {"exp": 1719859200})
|
||||
|
||||
def test_validate_jwt_token_expired(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.ExpiredSignatureError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
|
||||
def test_validate_jwt_token_invalid_audience(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidAudienceError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
|
||||
def test_validate_jwt_token_invalid_issuer(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidIssuerError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
|
||||
def test_validate_jwt_token_missing_required_claims(
|
||||
self, mock_jwt, mock_pyjwkclient
|
||||
):
|
||||
mock_jwt.decode.side_effect = jwt.MissingRequiredClaimError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
|
||||
def test_validate_jwt_token_jwks_error(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.exceptions.PyJWKClientError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
|
||||
def test_validate_jwt_token_invalid_token(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidTokenError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
1
lib/cli/tests/deploy/__init__.py
Normal file
1
lib/cli/tests/deploy/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Tests for CLI deploy."""
|
||||
266
lib/cli/tests/deploy/test_deploy_main.py
Normal file
266
lib/cli/tests/deploy/test_deploy_main.py
Normal file
@@ -0,0 +1,266 @@
|
||||
import sys
|
||||
import unittest
|
||||
from io import StringIO
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
import json
|
||||
|
||||
import httpx
|
||||
from crewai_cli.deploy.main import DeployCommand
|
||||
from crewai_cli.utils import parse_toml
|
||||
|
||||
|
||||
class TestDeployCommand(unittest.TestCase):
|
||||
@patch("crewai_cli.command.get_auth_token")
|
||||
@patch("crewai_cli.deploy.main.get_project_name")
|
||||
@patch("crewai_cli.command.PlusAPI")
|
||||
def setUp(self, mock_plus_api, mock_get_project_name, mock_get_auth_token):
|
||||
self.mock_get_auth_token = mock_get_auth_token
|
||||
self.mock_get_project_name = mock_get_project_name
|
||||
self.mock_plus_api = mock_plus_api
|
||||
|
||||
self.mock_get_auth_token.return_value = "test_token"
|
||||
self.mock_get_project_name.return_value = "test_project"
|
||||
|
||||
self.deploy_command = DeployCommand()
|
||||
self.mock_client = self.deploy_command.plus_api_client
|
||||
|
||||
def test_init_success(self):
|
||||
self.assertEqual(self.deploy_command.project_name, "test_project")
|
||||
self.mock_plus_api.assert_called_once_with(api_key="test_token")
|
||||
|
||||
@patch("crewai_cli.command.get_auth_token")
|
||||
def test_init_failure(self, mock_get_auth_token):
|
||||
mock_get_auth_token.side_effect = Exception("Auth failed")
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
DeployCommand()
|
||||
|
||||
def test_validate_response_successful_response(self):
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.json.return_value = {"message": "Success"}
|
||||
mock_response.status_code = 200
|
||||
mock_response.is_success = True
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command._validate_response(mock_response)
|
||||
assert fake_out.getvalue() == ""
|
||||
|
||||
def test_validate_response_json_decode_error(self):
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.json.side_effect = json.JSONDecodeError("Decode error", "", 0)
|
||||
mock_response.status_code = 500
|
||||
mock_response.content = b"Invalid JSON"
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with pytest.raises(SystemExit):
|
||||
self.deploy_command._validate_response(mock_response)
|
||||
output = fake_out.getvalue()
|
||||
assert (
|
||||
"Failed to parse response from Enterprise API failed. Details:"
|
||||
in output
|
||||
)
|
||||
assert "Status Code: 500" in output
|
||||
assert "Response:\nInvalid JSON" in output
|
||||
|
||||
def test_validate_response_422_error(self):
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.json.return_value = {
|
||||
"field1": ["Error message 1"],
|
||||
"field2": ["Error message 2"],
|
||||
}
|
||||
mock_response.status_code = 422
|
||||
mock_response.is_success = False
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with pytest.raises(SystemExit):
|
||||
self.deploy_command._validate_response(mock_response)
|
||||
output = fake_out.getvalue()
|
||||
assert (
|
||||
"Failed to complete operation. Please fix the following errors:"
|
||||
in output
|
||||
)
|
||||
assert "Field1 Error message 1" in output
|
||||
assert "Field2 Error message 2" in output
|
||||
|
||||
def test_validate_response_other_error(self):
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.json.return_value = {"error": "Something went wrong"}
|
||||
mock_response.status_code = 500
|
||||
mock_response.is_success = False
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
with pytest.raises(SystemExit):
|
||||
self.deploy_command._validate_response(mock_response)
|
||||
output = fake_out.getvalue()
|
||||
assert "Request to Enterprise API failed. Details:" in output
|
||||
assert "Details:\nSomething went wrong" in output
|
||||
|
||||
def test_standard_no_param_error_message(self):
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command._standard_no_param_error_message()
|
||||
self.assertIn("No UUID provided", fake_out.getvalue())
|
||||
|
||||
def test_display_deployment_info(self):
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command._display_deployment_info(
|
||||
{"uuid": "test-uuid", "status": "deployed"}
|
||||
)
|
||||
self.assertIn("Deploying the crew...", fake_out.getvalue())
|
||||
self.assertIn("test-uuid", fake_out.getvalue())
|
||||
self.assertIn("deployed", fake_out.getvalue())
|
||||
|
||||
def test_display_logs(self):
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command._display_logs(
|
||||
[{"timestamp": "2023-01-01", "level": "INFO", "message": "Test log"}]
|
||||
)
|
||||
self.assertIn("2023-01-01 - INFO: Test log", fake_out.getvalue())
|
||||
|
||||
@patch("crewai_cli.deploy.main.DeployCommand._display_deployment_info")
|
||||
def test_deploy_with_uuid(self, mock_display):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"uuid": "test-uuid"}
|
||||
self.mock_client.deploy_by_uuid.return_value = mock_response
|
||||
|
||||
self.deploy_command.deploy(uuid="test-uuid")
|
||||
|
||||
self.mock_client.deploy_by_uuid.assert_called_once_with("test-uuid")
|
||||
mock_display.assert_called_once_with({"uuid": "test-uuid"})
|
||||
|
||||
@patch("crewai_cli.deploy.main.DeployCommand._display_deployment_info")
|
||||
def test_deploy_with_project_name(self, mock_display):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"uuid": "test-uuid"}
|
||||
self.mock_client.deploy_by_name.return_value = mock_response
|
||||
|
||||
self.deploy_command.deploy()
|
||||
|
||||
self.mock_client.deploy_by_name.assert_called_once_with("test_project")
|
||||
mock_display.assert_called_once_with({"uuid": "test-uuid"})
|
||||
|
||||
@patch("crewai_cli.deploy.main.fetch_and_json_env_file")
|
||||
@patch("crewai_cli.deploy.main.git.Repository.origin_url")
|
||||
@patch("builtins.input")
|
||||
def test_create_crew(self, mock_input, mock_git_origin_url, mock_fetch_env):
|
||||
mock_fetch_env.return_value = {"ENV_VAR": "value"}
|
||||
mock_git_origin_url.return_value = "https://github.com/test/repo.git"
|
||||
mock_input.return_value = ""
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 201
|
||||
mock_response.json.return_value = {"uuid": "new-uuid", "status": "created"}
|
||||
self.mock_client.create_crew.return_value = mock_response
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command.create_crew()
|
||||
self.assertIn("Deployment created successfully!", fake_out.getvalue())
|
||||
self.assertIn("new-uuid", fake_out.getvalue())
|
||||
|
||||
def test_list_crews(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = [
|
||||
{"name": "Crew1", "uuid": "uuid1", "status": "active"},
|
||||
{"name": "Crew2", "uuid": "uuid2", "status": "inactive"},
|
||||
]
|
||||
self.mock_client.list_crews.return_value = mock_response
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command.list_crews()
|
||||
self.assertIn("Crew1 (uuid1) active", fake_out.getvalue())
|
||||
self.assertIn("Crew2 (uuid2) inactive", fake_out.getvalue())
|
||||
|
||||
def test_get_crew_status(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"name": "InternalCrew", "status": "active"}
|
||||
self.mock_client.crew_status_by_name.return_value = mock_response
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command.get_crew_status()
|
||||
self.assertIn("InternalCrew", fake_out.getvalue())
|
||||
self.assertIn("active", fake_out.getvalue())
|
||||
|
||||
def test_get_crew_logs(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = [
|
||||
{"timestamp": "2023-01-01", "level": "INFO", "message": "Log1"},
|
||||
{"timestamp": "2023-01-02", "level": "ERROR", "message": "Log2"},
|
||||
]
|
||||
self.mock_client.crew_by_name.return_value = mock_response
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command.get_crew_logs(None)
|
||||
self.assertIn("2023-01-01 - INFO: Log1", fake_out.getvalue())
|
||||
self.assertIn("2023-01-02 - ERROR: Log2", fake_out.getvalue())
|
||||
|
||||
def test_remove_crew(self):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 204
|
||||
self.mock_client.delete_crew_by_name.return_value = mock_response
|
||||
|
||||
with patch("sys.stdout", new=StringIO()) as fake_out:
|
||||
self.deploy_command.remove_crew(None)
|
||||
self.assertIn(
|
||||
"Crew 'test_project' removed successfully", fake_out.getvalue()
|
||||
)
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Requires Python 3.11+")
|
||||
def test_parse_toml_python_311_plus(self):
|
||||
toml_content = """
|
||||
[tool.poetry]
|
||||
name = "test_project"
|
||||
version = "0.1.0"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "^3.11"
|
||||
crewai = { extras = ["tools"], version = ">=0.51.0,<1.0.0" }
|
||||
"""
|
||||
parsed = parse_toml(toml_content)
|
||||
self.assertEqual(parsed["tool"]["poetry"]["name"], "test_project")
|
||||
|
||||
@patch(
|
||||
"builtins.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data="""
|
||||
[project]
|
||||
name = "test_project"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = ["crewai"]
|
||||
""",
|
||||
)
|
||||
def test_get_project_name_python_310(self, mock_open):
|
||||
from crewai_cli.utils import get_project_name
|
||||
|
||||
project_name = get_project_name()
|
||||
print("project_name", project_name)
|
||||
self.assertEqual(project_name, "test_project")
|
||||
|
||||
@unittest.skipIf(sys.version_info < (3, 11), "Requires Python 3.11+")
|
||||
@patch(
|
||||
"builtins.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data="""
|
||||
[project]
|
||||
name = "test_project"
|
||||
version = "0.1.0"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = ["crewai"]
|
||||
""",
|
||||
)
|
||||
def test_get_project_name_python_311_plus(self, mock_open):
|
||||
from crewai_cli.utils import get_project_name
|
||||
|
||||
project_name = get_project_name()
|
||||
self.assertEqual(project_name, "test_project")
|
||||
|
||||
def test_get_crewai_version(self):
|
||||
from crewai_cli.version import get_crewai_version
|
||||
|
||||
assert isinstance(get_crewai_version(), str)
|
||||
0
lib/cli/tests/enterprise/__init__.py
Normal file
0
lib/cli/tests/enterprise/__init__.py
Normal file
158
lib/cli/tests/enterprise/test_main.py
Normal file
158
lib/cli/tests/enterprise/test_main.py
Normal file
@@ -0,0 +1,158 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from crewai_cli.enterprise.main import EnterpriseConfigureCommand
|
||||
from crewai_cli.settings.main import SettingsCommand
|
||||
import shutil
|
||||
|
||||
|
||||
class TestEnterpriseConfigureCommand(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.test_dir = Path(tempfile.mkdtemp())
|
||||
self.config_path = self.test_dir / "settings.json"
|
||||
|
||||
with patch('crewai_cli.enterprise.main.SettingsCommand') as mock_settings_command_class:
|
||||
self.mock_settings_command = Mock(spec=SettingsCommand)
|
||||
mock_settings_command_class.return_value = self.mock_settings_command
|
||||
|
||||
self.enterprise_command = EnterpriseConfigureCommand()
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.test_dir)
|
||||
|
||||
@patch('crewai_cli.enterprise.main.httpx.get')
|
||||
@patch('crewai_cli.enterprise.main.get_crewai_version')
|
||||
def test_successful_configuration(self, mock_get_version, mock_requests_get):
|
||||
mock_get_version.return_value = "1.0.0"
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.json.return_value = {
|
||||
'audience': 'test_audience',
|
||||
'domain': 'test.domain.com',
|
||||
'device_authorization_client_id': 'test_client_id',
|
||||
'provider': 'workos',
|
||||
'extra': {}
|
||||
}
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
enterprise_url = "https://enterprise.example.com"
|
||||
self.enterprise_command.configure(enterprise_url)
|
||||
|
||||
expected_headers = {
|
||||
"Content-Type": "application/json",
|
||||
"User-Agent": "CrewAI-CLI/1.0.0",
|
||||
"X-Crewai-Version": "1.0.0",
|
||||
}
|
||||
mock_requests_get.assert_called_once_with(
|
||||
"https://enterprise.example.com/auth/parameters",
|
||||
timeout=30,
|
||||
headers=expected_headers
|
||||
)
|
||||
|
||||
expected_calls = [
|
||||
('enterprise_base_url', 'https://enterprise.example.com'),
|
||||
('oauth2_provider', 'workos'),
|
||||
('oauth2_audience', 'test_audience'),
|
||||
('oauth2_client_id', 'test_client_id'),
|
||||
('oauth2_domain', 'test.domain.com'),
|
||||
('oauth2_extra', {})
|
||||
]
|
||||
|
||||
actual_calls = self.mock_settings_command.set.call_args_list
|
||||
self.assertEqual(len(actual_calls), 6)
|
||||
|
||||
for i, (key, value) in enumerate(expected_calls):
|
||||
call_args = actual_calls[i][0]
|
||||
self.assertEqual(call_args[0], key)
|
||||
self.assertEqual(call_args[1], value)
|
||||
|
||||
@patch('crewai_cli.enterprise.main.httpx.get')
|
||||
@patch('crewai_cli.enterprise.main.get_crewai_version')
|
||||
def test_http_error_handling(self, mock_get_version, mock_requests_get):
|
||||
mock_get_version.return_value = "1.0.0"
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.raise_for_status.side_effect = httpx.HTTPStatusError(
|
||||
"404 Not Found",
|
||||
request=httpx.Request("GET", "http://test"),
|
||||
response=httpx.Response(404),
|
||||
)
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
self.enterprise_command.configure("https://enterprise.example.com")
|
||||
|
||||
@patch('crewai_cli.enterprise.main.httpx.get')
|
||||
@patch('crewai_cli.enterprise.main.get_crewai_version')
|
||||
def test_invalid_json_response(self, mock_get_version, mock_requests_get):
|
||||
mock_get_version.return_value = "1.0.0"
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.json.side_effect = json.JSONDecodeError("Invalid JSON", "", 0)
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
self.enterprise_command.configure("https://enterprise.example.com")
|
||||
|
||||
@patch('crewai_cli.enterprise.main.httpx.get')
|
||||
@patch('crewai_cli.enterprise.main.get_crewai_version')
|
||||
def test_missing_required_fields(self, mock_get_version, mock_requests_get):
|
||||
mock_get_version.return_value = "1.0.0"
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.json.return_value = {
|
||||
'audience': 'test_audience',
|
||||
}
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
self.enterprise_command.configure("https://enterprise.example.com")
|
||||
|
||||
@patch('crewai_cli.enterprise.main.httpx.get')
|
||||
@patch('crewai_cli.enterprise.main.get_crewai_version')
|
||||
def test_settings_update_error(self, mock_get_version, mock_requests_get):
|
||||
mock_get_version.return_value = "1.0.0"
|
||||
|
||||
mock_response = Mock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.raise_for_status.return_value = None
|
||||
mock_response.json.return_value = {
|
||||
'audience': 'test_audience',
|
||||
'domain': 'test.domain.com',
|
||||
'device_authorization_client_id': 'test_client_id',
|
||||
'provider': 'workos'
|
||||
}
|
||||
mock_requests_get.return_value = mock_response
|
||||
|
||||
self.mock_settings_command.set.side_effect = Exception("Settings update failed")
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
self.enterprise_command.configure("https://enterprise.example.com")
|
||||
|
||||
def test_url_trailing_slash_removal(self):
|
||||
with patch.object(self.enterprise_command, '_fetch_oauth_config') as mock_fetch, \
|
||||
patch.object(self.enterprise_command, '_update_oauth_settings') as mock_update:
|
||||
|
||||
mock_fetch.return_value = {
|
||||
'audience': 'test_audience',
|
||||
'domain': 'test.domain.com',
|
||||
'device_authorization_client_id': 'test_client_id',
|
||||
'provider': 'workos'
|
||||
}
|
||||
|
||||
self.enterprise_command.configure("https://enterprise.example.com/")
|
||||
|
||||
mock_fetch.assert_called_once_with("https://enterprise.example.com")
|
||||
mock_update.assert_called_once()
|
||||
1
lib/cli/tests/organization/__init__.py
Normal file
1
lib/cli/tests/organization/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
239
lib/cli/tests/organization/test_main.py
Normal file
239
lib/cli/tests/organization/test_main.py
Normal file
@@ -0,0 +1,239 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch, call
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
import httpx
|
||||
|
||||
from crewai_cli.organization.main import OrganizationCommand
|
||||
from crewai_cli.cli import org_list, switch, current
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return CliRunner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def org_command():
|
||||
with patch.object(OrganizationCommand, "__init__", return_value=None):
|
||||
command = OrganizationCommand()
|
||||
yield command
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_settings():
|
||||
with patch("crewai_cli.organization.main.Settings") as mock_settings_class:
|
||||
mock_settings_instance = MagicMock()
|
||||
mock_settings_class.return_value = mock_settings_instance
|
||||
yield mock_settings_instance
|
||||
|
||||
|
||||
@patch("crewai_cli.cli.OrganizationCommand")
|
||||
def test_org_list_command(mock_org_command_class, runner):
|
||||
mock_org_instance = MagicMock()
|
||||
mock_org_command_class.return_value = mock_org_instance
|
||||
|
||||
result = runner.invoke(org_list)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_org_command_class.assert_called_once()
|
||||
mock_org_instance.list.assert_called_once()
|
||||
|
||||
|
||||
@patch("crewai_cli.cli.OrganizationCommand")
|
||||
def test_org_switch_command(mock_org_command_class, runner):
|
||||
mock_org_instance = MagicMock()
|
||||
mock_org_command_class.return_value = mock_org_instance
|
||||
|
||||
result = runner.invoke(switch, ["test-id"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_org_command_class.assert_called_once()
|
||||
mock_org_instance.switch.assert_called_once_with("test-id")
|
||||
|
||||
|
||||
@patch("crewai_cli.cli.OrganizationCommand")
|
||||
def test_org_current_command(mock_org_command_class, runner):
|
||||
mock_org_instance = MagicMock()
|
||||
mock_org_command_class.return_value = mock_org_instance
|
||||
|
||||
result = runner.invoke(current)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_org_command_class.assert_called_once()
|
||||
mock_org_instance.current.assert_called_once()
|
||||
|
||||
|
||||
class TestOrganizationCommand(unittest.TestCase):
|
||||
def setUp(self):
|
||||
with patch.object(OrganizationCommand, "__init__", return_value=None):
|
||||
self.org_command = OrganizationCommand()
|
||||
self.org_command.plus_api_client = MagicMock()
|
||||
|
||||
@patch("crewai_cli.organization.main.console")
|
||||
@patch("crewai_cli.organization.main.Table")
|
||||
def test_list_organizations_success(self, mock_table, mock_console):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json.return_value = [
|
||||
{"name": "Org 1", "uuid": "org-123"},
|
||||
{"name": "Org 2", "uuid": "org-456"},
|
||||
]
|
||||
self.org_command.plus_api_client = MagicMock()
|
||||
self.org_command.plus_api_client.get_organizations.return_value = mock_response
|
||||
|
||||
mock_console.print = MagicMock()
|
||||
|
||||
self.org_command.list()
|
||||
|
||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||
mock_table.assert_called_once_with(title="Your Organizations")
|
||||
mock_table.return_value.add_column.assert_has_calls(
|
||||
[call("Name", style="cyan"), call("ID", style="green")]
|
||||
)
|
||||
mock_table.return_value.add_row.assert_has_calls(
|
||||
[call("Org 1", "org-123"), call("Org 2", "org-456")]
|
||||
)
|
||||
|
||||
@patch("crewai_cli.organization.main.console")
|
||||
def test_list_organizations_empty(self, mock_console):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json.return_value = []
|
||||
self.org_command.plus_api_client = MagicMock()
|
||||
self.org_command.plus_api_client.get_organizations.return_value = mock_response
|
||||
|
||||
self.org_command.list()
|
||||
|
||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||
mock_console.print.assert_called_once_with(
|
||||
"You don't belong to any organizations yet.", style="yellow"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.organization.main.console")
|
||||
def test_list_organizations_api_error(self, mock_console):
|
||||
self.org_command.plus_api_client = MagicMock()
|
||||
self.org_command.plus_api_client.get_organizations.side_effect = (
|
||||
httpx.HTTPError("API Error")
|
||||
)
|
||||
|
||||
with pytest.raises(SystemExit):
|
||||
self.org_command.list()
|
||||
|
||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||
mock_console.print.assert_called_once_with(
|
||||
"Failed to retrieve organization list: API Error", style="bold red"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.organization.main.console")
|
||||
@patch("crewai_cli.organization.main.Settings")
|
||||
def test_switch_organization_success(self, mock_settings_class, mock_console):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json.return_value = [
|
||||
{"name": "Org 1", "uuid": "org-123"},
|
||||
{"name": "Test Org", "uuid": "test-id"},
|
||||
]
|
||||
self.org_command.plus_api_client = MagicMock()
|
||||
self.org_command.plus_api_client.get_organizations.return_value = mock_response
|
||||
|
||||
mock_settings_instance = MagicMock()
|
||||
mock_settings_class.return_value = mock_settings_instance
|
||||
|
||||
self.org_command.switch("test-id")
|
||||
|
||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||
mock_settings_instance.dump.assert_called_once()
|
||||
assert mock_settings_instance.org_name == "Test Org"
|
||||
assert mock_settings_instance.org_uuid == "test-id"
|
||||
mock_console.print.assert_called_once_with(
|
||||
"Successfully switched to Test Org (test-id)", style="bold green"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.organization.main.console")
|
||||
def test_switch_organization_not_found(self, mock_console):
|
||||
mock_response = MagicMock()
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
mock_response.json.return_value = [
|
||||
{"name": "Org 1", "uuid": "org-123"},
|
||||
{"name": "Org 2", "uuid": "org-456"},
|
||||
]
|
||||
self.org_command.plus_api_client = MagicMock()
|
||||
self.org_command.plus_api_client.get_organizations.return_value = mock_response
|
||||
|
||||
self.org_command.switch("non-existent-id")
|
||||
|
||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||
mock_console.print.assert_called_once_with(
|
||||
"Organization with id 'non-existent-id' not found.", style="bold red"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.organization.main.console")
|
||||
@patch("crewai_cli.organization.main.Settings")
|
||||
def test_current_organization_with_org(self, mock_settings_class, mock_console):
|
||||
mock_settings_instance = MagicMock()
|
||||
mock_settings_instance.org_name = "Test Org"
|
||||
mock_settings_instance.org_uuid = "test-id"
|
||||
mock_settings_class.return_value = mock_settings_instance
|
||||
|
||||
self.org_command.current()
|
||||
|
||||
self.org_command.plus_api_client.get_organizations.assert_not_called()
|
||||
mock_console.print.assert_called_once_with(
|
||||
"Currently logged in to organization Test Org (test-id)", style="bold green"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.organization.main.console")
|
||||
@patch("crewai_cli.organization.main.Settings")
|
||||
def test_current_organization_without_org(self, mock_settings_class, mock_console):
|
||||
mock_settings_instance = MagicMock()
|
||||
mock_settings_instance.org_uuid = None
|
||||
mock_settings_class.return_value = mock_settings_instance
|
||||
|
||||
self.org_command.current()
|
||||
|
||||
assert mock_console.print.call_count == 3
|
||||
mock_console.print.assert_any_call(
|
||||
"You're not currently logged in to any organization.", style="yellow"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.organization.main.console")
|
||||
def test_list_organizations_unauthorized(self, mock_console):
|
||||
mock_response = MagicMock()
|
||||
mock_http_error = httpx.HTTPStatusError(
|
||||
"401 Client Error: Unauthorized",
|
||||
request=httpx.Request("GET", "http://test"),
|
||||
response=httpx.Response(401),
|
||||
)
|
||||
|
||||
mock_response.raise_for_status.side_effect = mock_http_error
|
||||
self.org_command.plus_api_client.get_organizations.return_value = mock_response
|
||||
|
||||
self.org_command.list()
|
||||
|
||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||
mock_console.print.assert_called_once_with(
|
||||
"You are not logged in to any organization. Use 'crewai login' to login.",
|
||||
style="bold red",
|
||||
)
|
||||
|
||||
@patch("crewai_cli.organization.main.console")
|
||||
def test_switch_organization_unauthorized(self, mock_console):
|
||||
mock_response = MagicMock()
|
||||
mock_http_error = httpx.HTTPStatusError(
|
||||
"401 Client Error: Unauthorized",
|
||||
request=httpx.Request("GET", "http://test"),
|
||||
response=httpx.Response(401),
|
||||
)
|
||||
|
||||
mock_response.raise_for_status.side_effect = mock_http_error
|
||||
self.org_command.plus_api_client.get_organizations.return_value = mock_response
|
||||
|
||||
self.org_command.switch("test-id")
|
||||
|
||||
self.org_command.plus_api_client.get_organizations.assert_called_once()
|
||||
mock_console.print.assert_called_once_with(
|
||||
"You are not logged in to any organization. Use 'crewai login' to login.",
|
||||
style="bold red",
|
||||
)
|
||||
255
lib/cli/tests/test_cli.py
Normal file
255
lib/cli/tests/test_cli.py
Normal file
@@ -0,0 +1,255 @@
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
from crewai_cli.cli import (
|
||||
deploy_create,
|
||||
deploy_list,
|
||||
deploy_logs,
|
||||
deploy_push,
|
||||
deploy_remove,
|
||||
deply_status,
|
||||
flow_add_crew,
|
||||
login,
|
||||
reset_memories,
|
||||
test,
|
||||
train,
|
||||
version,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return CliRunner()
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.train_crew")
|
||||
def test_train_default_iterations(train_crew, runner):
|
||||
result = runner.invoke(train)
|
||||
|
||||
train_crew.assert_called_once_with(5, "trained_agents_data.pkl")
|
||||
assert result.exit_code == 0
|
||||
assert "Training the Crew for 5 iterations" in result.output
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.train_crew")
|
||||
def test_train_custom_iterations(train_crew, runner):
|
||||
result = runner.invoke(train, ["--n_iterations", "10"])
|
||||
|
||||
train_crew.assert_called_once_with(10, "trained_agents_data.pkl")
|
||||
assert result.exit_code == 0
|
||||
assert "Training the Crew for 10 iterations" in result.output
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.train_crew")
|
||||
def test_train_invalid_string_iterations(train_crew, runner):
|
||||
result = runner.invoke(train, ["--n_iterations", "invalid"])
|
||||
|
||||
train_crew.assert_not_called()
|
||||
assert result.exit_code == 2
|
||||
assert (
|
||||
"Usage: train [OPTIONS]\nTry 'train --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n"
|
||||
in result.output
|
||||
)
|
||||
|
||||
|
||||
def test_reset_no_memory_flags(runner):
|
||||
result = runner.invoke(
|
||||
reset_memories,
|
||||
)
|
||||
assert (
|
||||
result.output
|
||||
== "Please specify at least one memory type to reset using the appropriate flags.\n"
|
||||
)
|
||||
|
||||
|
||||
def test_version_flag(runner):
|
||||
result = runner.invoke(version)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "crewai version:" in result.output
|
||||
|
||||
|
||||
def test_version_command(runner):
|
||||
result = runner.invoke(version)
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "crewai version:" in result.output
|
||||
|
||||
|
||||
def test_version_command_with_tools(runner):
|
||||
result = runner.invoke(version, ["--tools"])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "crewai version:" in result.output
|
||||
assert (
|
||||
"crewai tools version:" in result.output
|
||||
or "crewai tools not installed" in result.output
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.evaluate_crew")
|
||||
def test_test_default_iterations(evaluate_crew, runner):
|
||||
result = runner.invoke(test)
|
||||
|
||||
evaluate_crew.assert_called_once_with(3, "gpt-4o-mini")
|
||||
assert result.exit_code == 0
|
||||
assert "Testing the crew for 3 iterations with model gpt-4o-mini" in result.output
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.evaluate_crew")
|
||||
def test_test_custom_iterations(evaluate_crew, runner):
|
||||
result = runner.invoke(test, ["--n_iterations", "5", "--model", "gpt-4o"])
|
||||
|
||||
evaluate_crew.assert_called_once_with(5, "gpt-4o")
|
||||
assert result.exit_code == 0
|
||||
assert "Testing the crew for 5 iterations with model gpt-4o" in result.output
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.evaluate_crew")
|
||||
def test_test_invalid_string_iterations(evaluate_crew, runner):
|
||||
result = runner.invoke(test, ["--n_iterations", "invalid"])
|
||||
|
||||
evaluate_crew.assert_not_called()
|
||||
assert result.exit_code == 2
|
||||
assert (
|
||||
"Usage: test [OPTIONS]\nTry 'test --help' for help.\n\nError: Invalid value for '-n' / '--n_iterations': 'invalid' is not a valid integer.\n"
|
||||
in result.output
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.AuthenticationCommand")
|
||||
def test_login(command, runner):
|
||||
mock_auth = command.return_value
|
||||
result = runner.invoke(login)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_auth.login.assert_called_once()
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_create(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
result = runner.invoke(deploy_create)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.create_crew.assert_called_once()
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_list(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
result = runner.invoke(deploy_list)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.list_crews.assert_called_once()
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_push(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
uuid = "test-uuid"
|
||||
result = runner.invoke(deploy_push, ["-u", uuid])
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.deploy.assert_called_once_with(uuid=uuid)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_push_no_uuid(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
result = runner.invoke(deploy_push)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.deploy.assert_called_once_with(uuid=None)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_status(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
uuid = "test-uuid"
|
||||
result = runner.invoke(deply_status, ["-u", uuid])
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.get_crew_status.assert_called_once_with(uuid=uuid)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_status_no_uuid(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
result = runner.invoke(deply_status)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.get_crew_status.assert_called_once_with(uuid=None)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_logs(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
uuid = "test-uuid"
|
||||
result = runner.invoke(deploy_logs, ["-u", uuid])
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.get_crew_logs.assert_called_once_with(uuid=uuid)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_logs_no_uuid(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
result = runner.invoke(deploy_logs)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.get_crew_logs.assert_called_once_with(uuid=None)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_remove(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
uuid = "test-uuid"
|
||||
result = runner.invoke(deploy_remove, ["-u", uuid])
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.remove_crew.assert_called_once_with(uuid=uuid)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.cli.DeployCommand")
|
||||
def test_deploy_remove_no_uuid(command, runner):
|
||||
mock_deploy = command.return_value
|
||||
result = runner.invoke(deploy_remove)
|
||||
|
||||
assert result.exit_code == 0
|
||||
mock_deploy.remove_crew.assert_called_once_with(uuid=None)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.add_crew_to_flow.create_embedded_crew")
|
||||
@mock.patch("pathlib.Path.exists", return_value=True)
|
||||
def test_flow_add_crew(mock_path_exists, mock_create_embedded_crew, runner):
|
||||
crew_name = "new_crew"
|
||||
result = runner.invoke(flow_add_crew, [crew_name])
|
||||
|
||||
assert result.exit_code == 0, f"Command failed with output: {result.output}"
|
||||
assert f"Adding crew {crew_name} to the flow" in result.output
|
||||
|
||||
mock_create_embedded_crew.assert_called_once()
|
||||
call_args, call_kwargs = mock_create_embedded_crew.call_args
|
||||
assert call_args[0] == crew_name
|
||||
assert "parent_folder" in call_kwargs
|
||||
assert isinstance(call_kwargs["parent_folder"], Path)
|
||||
|
||||
|
||||
def test_add_crew_to_flow_not_in_root(runner):
|
||||
with mock.patch("pathlib.Path.exists", autospec=True) as mock_exists:
|
||||
def exists_side_effect(self):
|
||||
if self.name == "pyproject.toml":
|
||||
return False
|
||||
return True
|
||||
|
||||
mock_exists.side_effect = exists_side_effect
|
||||
|
||||
result = runner.invoke(flow_add_crew, ["new_crew"])
|
||||
|
||||
assert result.exit_code != 0
|
||||
assert "This command must be run from the root of a flow project." in str(
|
||||
result.output
|
||||
)
|
||||
148
lib/cli/tests/test_config.py
Normal file
148
lib/cli/tests/test_config.py
Normal file
@@ -0,0 +1,148 @@
|
||||
import json
|
||||
import shutil
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai_cli.config import (
|
||||
CLI_SETTINGS_KEYS,
|
||||
DEFAULT_CLI_SETTINGS,
|
||||
USER_SETTINGS_KEYS,
|
||||
Settings,
|
||||
)
|
||||
from crewai_cli.shared.token_manager import TokenManager
|
||||
|
||||
|
||||
class TestSettings(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.test_dir = Path(tempfile.mkdtemp())
|
||||
self.config_path = self.test_dir / "settings.json"
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.test_dir)
|
||||
|
||||
def test_empty_initialization(self):
|
||||
settings = Settings(config_path=self.config_path)
|
||||
self.assertIsNone(settings.tool_repository_username)
|
||||
self.assertIsNone(settings.tool_repository_password)
|
||||
|
||||
def test_initialization_with_data(self):
|
||||
settings = Settings(
|
||||
config_path=self.config_path, tool_repository_username="user1"
|
||||
)
|
||||
self.assertEqual(settings.tool_repository_username, "user1")
|
||||
self.assertIsNone(settings.tool_repository_password)
|
||||
|
||||
def test_initialization_with_existing_file(self):
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.config_path.open("w") as f:
|
||||
json.dump({"tool_repository_username": "file_user"}, f)
|
||||
|
||||
settings = Settings(config_path=self.config_path)
|
||||
self.assertEqual(settings.tool_repository_username, "file_user")
|
||||
|
||||
def test_merge_file_and_input_data(self):
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.config_path.open("w") as f:
|
||||
json.dump(
|
||||
{
|
||||
"tool_repository_username": "file_user",
|
||||
"tool_repository_password": "file_pass",
|
||||
},
|
||||
f,
|
||||
)
|
||||
|
||||
settings = Settings(
|
||||
config_path=self.config_path, tool_repository_username="new_user"
|
||||
)
|
||||
self.assertEqual(settings.tool_repository_username, "new_user")
|
||||
self.assertEqual(settings.tool_repository_password, "file_pass")
|
||||
|
||||
def test_clear_user_settings(self):
|
||||
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
|
||||
|
||||
settings = Settings(config_path=self.config_path, **user_settings)
|
||||
settings.clear_user_settings()
|
||||
|
||||
for key in user_settings.keys():
|
||||
self.assertEqual(getattr(settings, key), None)
|
||||
|
||||
@patch("crewai_cli.config.TokenManager")
|
||||
def test_reset_settings(self, mock_token_manager):
|
||||
user_settings = {key: f"value_for_{key}" for key in USER_SETTINGS_KEYS}
|
||||
cli_settings = {key: f"value_for_{key}" for key in CLI_SETTINGS_KEYS if key != "oauth2_extra"}
|
||||
cli_settings["oauth2_extra"] = {"scope": "xxx", "other": "yyy"}
|
||||
|
||||
settings = Settings(
|
||||
config_path=self.config_path, **user_settings, **cli_settings
|
||||
)
|
||||
|
||||
mock_token_manager.return_value = MagicMock()
|
||||
TokenManager().save_tokens(
|
||||
"aaa.bbb.ccc", (datetime.now() + timedelta(seconds=36000)).timestamp()
|
||||
)
|
||||
|
||||
settings.reset()
|
||||
|
||||
for key in user_settings.keys():
|
||||
self.assertEqual(getattr(settings, key), None)
|
||||
for key in cli_settings.keys():
|
||||
self.assertEqual(getattr(settings, key), DEFAULT_CLI_SETTINGS.get(key))
|
||||
|
||||
mock_token_manager.return_value.clear_tokens.assert_called_once()
|
||||
|
||||
def test_dump_new_settings(self):
|
||||
settings = Settings(
|
||||
config_path=self.config_path, tool_repository_username="user1"
|
||||
)
|
||||
settings.dump()
|
||||
|
||||
with self.config_path.open("r") as f:
|
||||
saved_data = json.load(f)
|
||||
|
||||
self.assertEqual(saved_data["tool_repository_username"], "user1")
|
||||
|
||||
def test_update_existing_settings(self):
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.config_path.open("w") as f:
|
||||
json.dump({"existing_setting": "value"}, f)
|
||||
|
||||
settings = Settings(
|
||||
config_path=self.config_path, tool_repository_username="user1"
|
||||
)
|
||||
settings.dump()
|
||||
|
||||
with self.config_path.open("r") as f:
|
||||
saved_data = json.load(f)
|
||||
|
||||
self.assertEqual(saved_data["existing_setting"], "value")
|
||||
self.assertEqual(saved_data["tool_repository_username"], "user1")
|
||||
|
||||
def test_none_values(self):
|
||||
settings = Settings(config_path=self.config_path, tool_repository_username=None)
|
||||
settings.dump()
|
||||
|
||||
with self.config_path.open("r") as f:
|
||||
saved_data = json.load(f)
|
||||
|
||||
self.assertIsNone(saved_data.get("tool_repository_username"))
|
||||
|
||||
def test_invalid_json_in_config(self):
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with self.config_path.open("w") as f:
|
||||
f.write("invalid json")
|
||||
|
||||
try:
|
||||
settings = Settings(config_path=self.config_path)
|
||||
self.assertIsNone(settings.tool_repository_username)
|
||||
except json.JSONDecodeError:
|
||||
self.fail("Settings initialization should handle invalid JSON")
|
||||
|
||||
def test_empty_config_file(self):
|
||||
self.config_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self.config_path.touch()
|
||||
|
||||
settings = Settings(config_path=self.config_path)
|
||||
self.assertIsNone(settings.tool_repository_username)
|
||||
20
lib/cli/tests/test_constants.py
Normal file
20
lib/cli/tests/test_constants.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from crewai_cli.constants import ENV_VARS, MODELS, PROVIDERS
|
||||
|
||||
|
||||
def test_huggingface_in_providers():
|
||||
"""Test that Huggingface is in the PROVIDERS list."""
|
||||
assert "huggingface" in PROVIDERS
|
||||
|
||||
|
||||
def test_huggingface_env_vars():
|
||||
"""Test that Huggingface environment variables are properly configured."""
|
||||
assert "huggingface" in ENV_VARS
|
||||
assert any(
|
||||
detail.get("key_name") == "HF_TOKEN" for detail in ENV_VARS["huggingface"]
|
||||
)
|
||||
|
||||
|
||||
def test_huggingface_models():
|
||||
"""Test that Huggingface models are properly configured."""
|
||||
assert "huggingface" in MODELS
|
||||
assert len(MODELS["huggingface"]) > 0
|
||||
347
lib/cli/tests/test_create_crew.py
Normal file
347
lib/cli/tests/test_create_crew.py
Normal file
@@ -0,0 +1,347 @@
|
||||
import keyword
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
from crewai_cli.create_crew import create_crew, create_folder_structure
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def runner():
|
||||
return CliRunner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_dir():
|
||||
temp_path = tempfile.mkdtemp()
|
||||
yield temp_path
|
||||
shutil.rmtree(temp_path)
|
||||
|
||||
|
||||
def test_create_folder_structure_strips_single_trailing_slash():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
folder_path, folder_name, class_name = create_folder_structure(
|
||||
"hello/", parent_folder=temp_dir
|
||||
)
|
||||
|
||||
assert folder_name == "hello"
|
||||
assert class_name == "Hello"
|
||||
assert folder_path.name == "hello"
|
||||
assert folder_path.exists()
|
||||
assert folder_path.parent == Path(temp_dir)
|
||||
|
||||
|
||||
def test_create_folder_structure_strips_multiple_trailing_slashes():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
folder_path, folder_name, class_name = create_folder_structure(
|
||||
"hello///", parent_folder=temp_dir
|
||||
)
|
||||
|
||||
assert folder_name == "hello"
|
||||
assert class_name == "Hello"
|
||||
assert folder_path.name == "hello"
|
||||
assert folder_path.exists()
|
||||
assert folder_path.parent == Path(temp_dir)
|
||||
|
||||
|
||||
def test_create_folder_structure_handles_complex_name_with_trailing_slash():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
folder_path, folder_name, class_name = create_folder_structure(
|
||||
"my-awesome_project/", parent_folder=temp_dir
|
||||
)
|
||||
|
||||
assert folder_name == "my_awesome_project"
|
||||
assert class_name == "MyAwesomeProject"
|
||||
assert folder_path.name == "my_awesome_project"
|
||||
assert folder_path.exists()
|
||||
assert folder_path.parent == Path(temp_dir)
|
||||
|
||||
|
||||
def test_create_folder_structure_normal_name_unchanged():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
folder_path, folder_name, class_name = create_folder_structure(
|
||||
"hello", parent_folder=temp_dir
|
||||
)
|
||||
|
||||
assert folder_name == "hello"
|
||||
assert class_name == "Hello"
|
||||
assert folder_path.name == "hello"
|
||||
assert folder_path.exists()
|
||||
assert folder_path.parent == Path(temp_dir)
|
||||
|
||||
|
||||
def test_create_folder_structure_with_parent_folder():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
parent_path = Path(temp_dir) / "parent"
|
||||
parent_path.mkdir()
|
||||
|
||||
folder_path, folder_name, class_name = create_folder_structure(
|
||||
"child/", parent_folder=parent_path
|
||||
)
|
||||
|
||||
assert folder_name == "child"
|
||||
assert class_name == "Child"
|
||||
assert folder_path.name == "child"
|
||||
assert folder_path.parent == parent_path
|
||||
assert folder_path.exists()
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.create_crew.copy_template")
|
||||
@mock.patch("crewai_cli.create_crew.write_env_file")
|
||||
@mock.patch("crewai_cli.create_crew.load_env_vars")
|
||||
def test_create_crew_with_trailing_slash_creates_valid_project(
|
||||
mock_load_env, mock_write_env, mock_copy_template, temp_dir
|
||||
):
|
||||
mock_load_env.return_value = {}
|
||||
|
||||
with tempfile.TemporaryDirectory() as work_dir:
|
||||
with mock.patch(
|
||||
"crewai_cli.create_crew.create_folder_structure"
|
||||
) as mock_create_folder:
|
||||
mock_folder_path = Path(work_dir) / "test_project"
|
||||
mock_create_folder.return_value = (
|
||||
mock_folder_path,
|
||||
"test_project",
|
||||
"TestProject",
|
||||
)
|
||||
|
||||
create_crew("test-project/", skip_provider=True)
|
||||
|
||||
mock_create_folder.assert_called_once_with("test-project/", None)
|
||||
mock_copy_template.assert_called()
|
||||
copy_calls = mock_copy_template.call_args_list
|
||||
|
||||
for call in copy_calls:
|
||||
args = call[0]
|
||||
if len(args) >= 5:
|
||||
folder_name_arg = args[4]
|
||||
assert not folder_name_arg.endswith("/"), (
|
||||
f"folder_name should not end with slash: {folder_name_arg}"
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.create_crew.copy_template")
|
||||
@mock.patch("crewai_cli.create_crew.write_env_file")
|
||||
@mock.patch("crewai_cli.create_crew.load_env_vars")
|
||||
def test_create_crew_with_multiple_trailing_slashes(
|
||||
mock_load_env, mock_write_env, mock_copy_template, temp_dir
|
||||
):
|
||||
mock_load_env.return_value = {}
|
||||
|
||||
with tempfile.TemporaryDirectory() as work_dir:
|
||||
with mock.patch(
|
||||
"crewai_cli.create_crew.create_folder_structure"
|
||||
) as mock_create_folder:
|
||||
mock_folder_path = Path(work_dir) / "test_project"
|
||||
mock_create_folder.return_value = (
|
||||
mock_folder_path,
|
||||
"test_project",
|
||||
"TestProject",
|
||||
)
|
||||
|
||||
create_crew("test-project///", skip_provider=True)
|
||||
|
||||
mock_create_folder.assert_called_once_with("test-project///", None)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.create_crew.copy_template")
|
||||
@mock.patch("crewai_cli.create_crew.write_env_file")
|
||||
@mock.patch("crewai_cli.create_crew.load_env_vars")
|
||||
def test_create_crew_normal_name_still_works(
|
||||
mock_load_env, mock_write_env, mock_copy_template, temp_dir
|
||||
):
|
||||
mock_load_env.return_value = {}
|
||||
|
||||
with tempfile.TemporaryDirectory() as work_dir:
|
||||
with mock.patch(
|
||||
"crewai_cli.create_crew.create_folder_structure"
|
||||
) as mock_create_folder:
|
||||
mock_folder_path = Path(work_dir) / "normal_project"
|
||||
mock_create_folder.return_value = (
|
||||
mock_folder_path,
|
||||
"normal_project",
|
||||
"NormalProject",
|
||||
)
|
||||
|
||||
create_crew("normal-project", skip_provider=True)
|
||||
|
||||
mock_create_folder.assert_called_once_with("normal-project", None)
|
||||
|
||||
|
||||
def test_create_folder_structure_handles_spaces_and_dashes_with_slash():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
folder_path, folder_name, class_name = create_folder_structure(
|
||||
"My Cool-Project/", parent_folder=temp_dir
|
||||
)
|
||||
|
||||
assert folder_name == "my_cool_project"
|
||||
assert class_name == "MyCoolProject"
|
||||
assert folder_path.name == "my_cool_project"
|
||||
assert folder_path.exists()
|
||||
assert folder_path.parent == Path(temp_dir)
|
||||
|
||||
|
||||
def test_create_folder_structure_raises_error_for_invalid_names():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
invalid_cases = [
|
||||
("123project/", "cannot start with a digit"),
|
||||
("True/", "reserved Python keyword"),
|
||||
("False/", "reserved Python keyword"),
|
||||
("None/", "reserved Python keyword"),
|
||||
("class/", "reserved Python keyword"),
|
||||
("def/", "reserved Python keyword"),
|
||||
(" /", "empty or contain only whitespace"),
|
||||
("", "empty or contain only whitespace"),
|
||||
("@#$/", "contains no valid characters"),
|
||||
]
|
||||
|
||||
for invalid_name, expected_error in invalid_cases:
|
||||
with pytest.raises(ValueError, match=expected_error):
|
||||
create_folder_structure(invalid_name, parent_folder=temp_dir)
|
||||
|
||||
|
||||
def test_create_folder_structure_validates_names():
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
valid_cases = [
|
||||
("hello/", "hello", "Hello"),
|
||||
("my-project/", "my_project", "MyProject"),
|
||||
("hello_world/", "hello_world", "HelloWorld"),
|
||||
("valid123/", "valid123", "Valid123"),
|
||||
("hello.world/", "helloworld", "HelloWorld"),
|
||||
("hello@world/", "helloworld", "HelloWorld"),
|
||||
]
|
||||
|
||||
for valid_name, expected_folder, expected_class in valid_cases:
|
||||
folder_path, folder_name, class_name = create_folder_structure(
|
||||
valid_name, parent_folder=temp_dir
|
||||
)
|
||||
assert folder_name == expected_folder
|
||||
assert class_name == expected_class
|
||||
|
||||
assert folder_name.isidentifier(), (
|
||||
f"folder_name '{folder_name}' should be valid Python identifier"
|
||||
)
|
||||
assert not keyword.iskeyword(folder_name), (
|
||||
f"folder_name '{folder_name}' should not be Python keyword"
|
||||
)
|
||||
assert not folder_name[0].isdigit(), (
|
||||
f"folder_name '{folder_name}' should not start with digit"
|
||||
)
|
||||
|
||||
assert class_name.isidentifier(), (
|
||||
f"class_name '{class_name}' should be valid Python identifier"
|
||||
)
|
||||
assert not keyword.iskeyword(class_name), (
|
||||
f"class_name '{class_name}' should not be Python keyword"
|
||||
)
|
||||
assert folder_path.parent == Path(temp_dir)
|
||||
|
||||
if folder_path.exists():
|
||||
shutil.rmtree(folder_path)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.create_crew.copy_template")
|
||||
@mock.patch("crewai_cli.create_crew.write_env_file")
|
||||
@mock.patch("crewai_cli.create_crew.load_env_vars")
|
||||
def test_create_crew_with_parent_folder_and_trailing_slash(
|
||||
mock_load_env, mock_write_env, mock_copy_template, temp_dir
|
||||
):
|
||||
mock_load_env.return_value = {}
|
||||
|
||||
with tempfile.TemporaryDirectory() as work_dir:
|
||||
parent_path = Path(work_dir) / "parent"
|
||||
parent_path.mkdir()
|
||||
|
||||
create_crew("child-crew/", skip_provider=True, parent_folder=parent_path)
|
||||
|
||||
crew_path = parent_path / "child_crew"
|
||||
assert crew_path.exists()
|
||||
assert not (crew_path / "src").exists()
|
||||
|
||||
|
||||
def test_create_folder_structure_folder_name_validation():
|
||||
"""Test that folder names are validated as valid Python module names"""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
folder_invalid_cases = [
|
||||
("123invalid/", "cannot start with a digit.*invalid Python module name"),
|
||||
("import/", "reserved Python keyword"),
|
||||
("class/", "reserved Python keyword"),
|
||||
("for/", "reserved Python keyword"),
|
||||
("@#$invalid/", "contains no valid characters.*Python module name"),
|
||||
]
|
||||
|
||||
for invalid_name, expected_error in folder_invalid_cases:
|
||||
with pytest.raises(ValueError, match=expected_error):
|
||||
create_folder_structure(invalid_name, parent_folder=temp_dir)
|
||||
|
||||
valid_cases = [
|
||||
("hello-world/", "hello_world"),
|
||||
("my.project/", "myproject"),
|
||||
("test@123/", "test123"),
|
||||
("valid_name/", "valid_name"),
|
||||
]
|
||||
|
||||
for valid_name, expected_folder in valid_cases:
|
||||
folder_path, folder_name, class_name = create_folder_structure(
|
||||
valid_name, parent_folder=temp_dir
|
||||
)
|
||||
assert folder_name == expected_folder
|
||||
assert folder_name.isidentifier()
|
||||
assert not keyword.iskeyword(folder_name)
|
||||
|
||||
if folder_path.exists():
|
||||
shutil.rmtree(folder_path)
|
||||
|
||||
|
||||
def test_create_folder_structure_rejects_reserved_names():
|
||||
"""Test that reserved script names are rejected to prevent pyproject.toml conflicts."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
reserved_names = ["test", "train", "replay", "run_crew", "run_with_trigger"]
|
||||
|
||||
for reserved_name in reserved_names:
|
||||
with pytest.raises(ValueError, match="which is reserved"):
|
||||
create_folder_structure(reserved_name, parent_folder=temp_dir)
|
||||
|
||||
with pytest.raises(ValueError, match="which is reserved"):
|
||||
create_folder_structure(f"{reserved_name}/", parent_folder=temp_dir)
|
||||
|
||||
capitalized = reserved_name.capitalize()
|
||||
with pytest.raises(ValueError, match="which is reserved"):
|
||||
create_folder_structure(capitalized, parent_folder=temp_dir)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.create_crew.create_folder_structure")
|
||||
@mock.patch("crewai_cli.create_crew.copy_template")
|
||||
@mock.patch("crewai_cli.create_crew.load_env_vars")
|
||||
@mock.patch("crewai_cli.create_crew.get_provider_data")
|
||||
@mock.patch("crewai_cli.create_crew.select_provider")
|
||||
@mock.patch("crewai_cli.create_crew.select_model")
|
||||
@mock.patch("click.prompt")
|
||||
def test_env_vars_are_uppercased_in_env_file(
|
||||
mock_prompt,
|
||||
mock_select_model,
|
||||
mock_select_provider,
|
||||
mock_get_provider_data,
|
||||
mock_load_env_vars,
|
||||
mock_copy_template,
|
||||
mock_create_folder_structure,
|
||||
tmp_path,
|
||||
):
|
||||
crew_path = tmp_path / "test_crew"
|
||||
crew_path.mkdir()
|
||||
mock_create_folder_structure.return_value = (crew_path, "test_crew", "TestCrew")
|
||||
|
||||
mock_load_env_vars.return_value = {}
|
||||
mock_get_provider_data.return_value = {"openai": ["gpt-4"]}
|
||||
mock_select_provider.return_value = "azure"
|
||||
mock_select_model.return_value = "azure/openai"
|
||||
mock_prompt.return_value = "fake-api-key"
|
||||
|
||||
create_crew("Test Crew")
|
||||
|
||||
env_file_path = crew_path / ".env"
|
||||
content = env_file_path.read_text()
|
||||
assert "MODEL=" in content
|
||||
97
lib/cli/tests/test_crew_test.py
Normal file
97
lib/cli/tests/test_crew_test.py
Normal file
@@ -0,0 +1,97 @@
|
||||
import subprocess
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_cli import evaluate_crew
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"n_iterations,model",
|
||||
[
|
||||
(1, "gpt-4o"),
|
||||
(5, "gpt-3.5-turbo"),
|
||||
(10, "gpt-4"),
|
||||
],
|
||||
)
|
||||
@mock.patch("crewai_cli.evaluate_crew.subprocess.run")
|
||||
def test_crew_success(mock_subprocess_run, n_iterations, model):
|
||||
"""Test the crew function for successful execution."""
|
||||
mock_subprocess_run.return_value = subprocess.CompletedProcess(
|
||||
args=f"uv run test {n_iterations} {model}", returncode=0
|
||||
)
|
||||
result = evaluate_crew.evaluate_crew(n_iterations, model)
|
||||
|
||||
mock_subprocess_run.assert_called_once_with(
|
||||
["uv", "run", "test", str(n_iterations), model],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
assert result is None
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.evaluate_crew.click")
|
||||
def test_test_crew_zero_iterations(click):
|
||||
evaluate_crew.evaluate_crew(0, "gpt-4o")
|
||||
click.echo.assert_called_once_with(
|
||||
"An unexpected error occurred: The number of iterations must be a positive integer.",
|
||||
err=True,
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.evaluate_crew.click")
|
||||
def test_test_crew_negative_iterations(click):
|
||||
evaluate_crew.evaluate_crew(-2, "gpt-4o")
|
||||
click.echo.assert_called_once_with(
|
||||
"An unexpected error occurred: The number of iterations must be a positive integer.",
|
||||
err=True,
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.evaluate_crew.click")
|
||||
@mock.patch("crewai_cli.evaluate_crew.subprocess.run")
|
||||
def test_test_crew_called_process_error(mock_subprocess_run, click):
|
||||
n_iterations = 5
|
||||
mock_subprocess_run.side_effect = subprocess.CalledProcessError(
|
||||
returncode=1,
|
||||
cmd=["uv", "run", "test", str(n_iterations), "gpt-4o"],
|
||||
output="Error",
|
||||
stderr="Some error occurred",
|
||||
)
|
||||
evaluate_crew.evaluate_crew(n_iterations, "gpt-4o")
|
||||
|
||||
mock_subprocess_run.assert_called_once_with(
|
||||
["uv", "run", "test", "5", "gpt-4o"],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
click.echo.assert_has_calls(
|
||||
[
|
||||
mock.call.echo(
|
||||
"An error occurred while testing the crew: Command '['uv', 'run', 'test', '5', 'gpt-4o']' returned non-zero exit status 1.",
|
||||
err=True,
|
||||
),
|
||||
mock.call.echo("Error", err=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.evaluate_crew.click")
|
||||
@mock.patch("crewai_cli.evaluate_crew.subprocess.run")
|
||||
def test_test_crew_unexpected_exception(mock_subprocess_run, click):
|
||||
# Arrange
|
||||
n_iterations = 5
|
||||
mock_subprocess_run.side_effect = Exception("Unexpected error")
|
||||
evaluate_crew.evaluate_crew(n_iterations, "gpt-4o")
|
||||
|
||||
mock_subprocess_run.assert_called_once_with(
|
||||
["uv", "run", "test", "5", "gpt-4o"],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
click.echo.assert_called_once_with(
|
||||
"An unexpected error occurred: Unexpected error", err=True
|
||||
)
|
||||
101
lib/cli/tests/test_git.py
Normal file
101
lib/cli/tests/test_git.py
Normal file
@@ -0,0 +1,101 @@
|
||||
import pytest
|
||||
from crewai_cli.git import Repository
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def repository(fp):
|
||||
fp.register(["git", "--version"], stdout="git version 2.30.0\n")
|
||||
fp.register(["git", "rev-parse", "--is-inside-work-tree"], stdout="true\n")
|
||||
fp.register(["git", "fetch"], stdout="")
|
||||
return Repository(path=".")
|
||||
|
||||
|
||||
def test_init_with_invalid_git_repo(fp):
|
||||
fp.register(["git", "--version"], stdout="git version 2.30.0\n")
|
||||
fp.register(
|
||||
["git", "rev-parse", "--is-inside-work-tree"],
|
||||
returncode=1,
|
||||
stderr="fatal: not a git repository\n",
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
Repository(path="invalid/path")
|
||||
|
||||
|
||||
def test_is_git_not_installed(fp):
|
||||
fp.register(["git", "--version"], returncode=1)
|
||||
|
||||
with pytest.raises(
|
||||
ValueError, match="Git is not installed or not found in your PATH."
|
||||
):
|
||||
Repository(path=".")
|
||||
|
||||
|
||||
def test_status(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n",
|
||||
)
|
||||
assert repository.status() == "## main...origin/main [ahead 1]"
|
||||
|
||||
|
||||
def test_has_uncommitted_changes(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main\n M somefile.txt\n",
|
||||
)
|
||||
assert repository.has_uncommitted_changes() is True
|
||||
|
||||
|
||||
def test_is_ahead_or_behind(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n",
|
||||
)
|
||||
assert repository.is_ahead_or_behind() is True
|
||||
|
||||
|
||||
def test_is_synced_when_synced(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n"
|
||||
)
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"], stdout="## main...origin/main\n"
|
||||
)
|
||||
assert repository.is_synced() is True
|
||||
|
||||
|
||||
def test_is_synced_with_uncommitted_changes(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main\n M somefile.txt\n",
|
||||
)
|
||||
assert repository.is_synced() is False
|
||||
|
||||
|
||||
def test_is_synced_when_ahead_or_behind(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n",
|
||||
)
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n",
|
||||
)
|
||||
assert repository.is_synced() is False
|
||||
|
||||
|
||||
def test_is_synced_with_uncommitted_changes_and_ahead(fp, repository):
|
||||
fp.register(
|
||||
["git", "status", "--branch", "--porcelain"],
|
||||
stdout="## main...origin/main [ahead 1]\n M somefile.txt\n",
|
||||
)
|
||||
assert repository.is_synced() is False
|
||||
|
||||
|
||||
def test_origin_url(fp, repository):
|
||||
fp.register(
|
||||
["git", "remote", "get-url", "origin"],
|
||||
stdout="https://github.com/user/repo.git\n",
|
||||
)
|
||||
assert repository.origin_url() == "https://github.com/user/repo.git"
|
||||
356
lib/cli/tests/test_plus_api.py
Normal file
356
lib/cli/tests/test_plus_api.py
Normal file
@@ -0,0 +1,356 @@
|
||||
import os
|
||||
import unittest
|
||||
from unittest.mock import ANY, AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai_cli.plus_api import PlusAPI
|
||||
|
||||
|
||||
class TestPlusAPI(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.api_key = "test_api_key"
|
||||
self.api = PlusAPI(self.api_key)
|
||||
self.org_uuid = "test-org-uuid"
|
||||
|
||||
def test_init(self):
|
||||
self.assertEqual(self.api.api_key, self.api_key)
|
||||
self.assertEqual(self.api.headers["Authorization"], f"Bearer {self.api_key}")
|
||||
self.assertEqual(self.api.headers["Content-Type"], "application/json")
|
||||
self.assertTrue("CrewAI-CLI/" in self.api.headers["User-Agent"])
|
||||
self.assertTrue(self.api.headers["X-Crewai-Version"])
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_login_to_tool_repository(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
|
||||
response = self.api.login_to_tool_repository()
|
||||
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools/login"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
def assert_request_with_org_id(
|
||||
self, mock_client_instance, method: str, endpoint: str, **kwargs
|
||||
):
|
||||
mock_client_instance.request.assert_called_once_with(
|
||||
method,
|
||||
f"{os.getenv('CREWAI_PLUS_URL')}{endpoint}",
|
||||
headers={
|
||||
"Authorization": ANY,
|
||||
"Content-Type": ANY,
|
||||
"User-Agent": ANY,
|
||||
"X-Crewai-Version": ANY,
|
||||
"X-Crewai-Organization-Id": self.org_uuid,
|
||||
},
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.Settings")
|
||||
@patch("crewai_cli.plus_api.httpx.Client")
|
||||
def test_login_to_tool_repository_with_org_uuid(
|
||||
self, mock_client_class, mock_settings_class
|
||||
):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = self.org_uuid
|
||||
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
|
||||
mock_settings_class.return_value = mock_settings
|
||||
self.api = PlusAPI(self.api_key)
|
||||
|
||||
mock_client_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_client_class.return_value.__enter__.return_value = mock_client_instance
|
||||
|
||||
response = self.api.login_to_tool_repository()
|
||||
|
||||
self.assert_request_with_org_id(
|
||||
mock_client_instance, "POST", "/crewai_plus/api/v1/tools/login"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_get_tool(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
|
||||
response = self.api.get_tool("test_tool_handle")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/tools/test_tool_handle"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai_cli.plus_api.Settings")
|
||||
@patch("crewai_cli.plus_api.httpx.Client")
|
||||
def test_get_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = self.org_uuid
|
||||
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
|
||||
mock_settings_class.return_value = mock_settings
|
||||
self.api = PlusAPI(self.api_key)
|
||||
|
||||
mock_client_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_client_class.return_value.__enter__.return_value = mock_client_instance
|
||||
|
||||
response = self.api.get_tool("test_tool_handle")
|
||||
|
||||
self.assert_request_with_org_id(
|
||||
mock_client_instance, "GET", "/crewai_plus/api/v1/tools/test_tool_handle"
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_publish_tool(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
handle = "test_tool_handle"
|
||||
public = True
|
||||
version = "1.0.0"
|
||||
description = "Test tool description"
|
||||
encoded_file = "encoded_test_file"
|
||||
|
||||
response = self.api.publish_tool(
|
||||
handle, public, version, description, encoded_file
|
||||
)
|
||||
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": None,
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai_cli.plus_api.Settings")
|
||||
@patch("crewai_cli.plus_api.httpx.Client")
|
||||
def test_publish_tool_with_org_uuid(self, mock_client_class, mock_settings_class):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = self.org_uuid
|
||||
mock_settings.enterprise_base_url = os.getenv('CREWAI_PLUS_URL')
|
||||
mock_settings_class.return_value = mock_settings
|
||||
self.api = PlusAPI(self.api_key)
|
||||
|
||||
mock_client_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_client_class.return_value.__enter__.return_value = mock_client_instance
|
||||
|
||||
handle = "test_tool_handle"
|
||||
public = True
|
||||
version = "1.0.0"
|
||||
description = "Test tool description"
|
||||
encoded_file = "encoded_test_file"
|
||||
|
||||
response = self.api.publish_tool(
|
||||
handle, public, version, description, encoded_file
|
||||
)
|
||||
|
||||
expected_params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": None,
|
||||
}
|
||||
|
||||
self.assert_request_with_org_id(
|
||||
mock_client_instance, "POST", "/crewai_plus/api/v1/tools", json=expected_params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_publish_tool_without_description(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
handle = "test_tool_handle"
|
||||
public = False
|
||||
version = "2.0.0"
|
||||
description = None
|
||||
encoded_file = "encoded_test_file"
|
||||
|
||||
response = self.api.publish_tool(
|
||||
handle, public, version, description, encoded_file
|
||||
)
|
||||
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": None,
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai_cli.plus_api.httpx.Client")
|
||||
def test_make_request(self, mock_client_class):
|
||||
mock_client_instance = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_client_instance.request.return_value = mock_response
|
||||
mock_client_class.return_value.__enter__.return_value = mock_client_instance
|
||||
|
||||
response = self.api._make_request("GET", "test_endpoint")
|
||||
|
||||
mock_client_class.assert_called_once_with(trust_env=False, verify=True)
|
||||
mock_client_instance.request.assert_called_once_with(
|
||||
"GET", f"{self.api.base_url}/test_endpoint", headers=self.api.headers
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_deploy_by_name(self, mock_make_request):
|
||||
self.api.deploy_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/crews/by-name/test_project/deploy"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_deploy_by_uuid(self, mock_make_request):
|
||||
self.api.deploy_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/crews/test_uuid/deploy"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_crew_status_by_name(self, mock_make_request):
|
||||
self.api.crew_status_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/status"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_crew_status_by_uuid(self, mock_make_request):
|
||||
self.api.crew_status_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/test_uuid/status"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_crew_by_name(self, mock_make_request):
|
||||
self.api.crew_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/deployment"
|
||||
)
|
||||
|
||||
self.api.crew_by_name("test_project", "custom_log")
|
||||
mock_make_request.assert_called_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/by-name/test_project/logs/custom_log"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_crew_by_uuid(self, mock_make_request):
|
||||
self.api.crew_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/deployment"
|
||||
)
|
||||
|
||||
self.api.crew_by_uuid("test_uuid", "custom_log")
|
||||
mock_make_request.assert_called_with(
|
||||
"GET", "/crewai_plus/api/v1/crews/test_uuid/logs/custom_log"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_delete_crew_by_name(self, mock_make_request):
|
||||
self.api.delete_crew_by_name("test_project")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"DELETE", "/crewai_plus/api/v1/crews/by-name/test_project"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_delete_crew_by_uuid(self, mock_make_request):
|
||||
self.api.delete_crew_by_uuid("test_uuid")
|
||||
mock_make_request.assert_called_once_with(
|
||||
"DELETE", "/crewai_plus/api/v1/crews/test_uuid"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_list_crews(self, mock_make_request):
|
||||
self.api.list_crews()
|
||||
mock_make_request.assert_called_once_with("GET", "/crewai_plus/api/v1/crews")
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI._make_request")
|
||||
def test_create_crew(self, mock_make_request):
|
||||
payload = {"name": "test_crew"}
|
||||
self.api.create_crew(payload)
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/crews", json=payload
|
||||
)
|
||||
|
||||
@patch("crewai_cli.plus_api.Settings")
|
||||
@patch.dict(os.environ, {"CREWAI_PLUS_URL": ""})
|
||||
def test_custom_base_url(self, mock_settings_class):
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.enterprise_base_url = "https://custom-url.com/api"
|
||||
mock_settings_class.return_value = mock_settings
|
||||
custom_api = PlusAPI("test_key")
|
||||
self.assertEqual(
|
||||
custom_api.base_url,
|
||||
"https://custom-url.com/api",
|
||||
)
|
||||
|
||||
@patch.dict(os.environ, {"CREWAI_PLUS_URL": "https://custom-url-from-env.com"})
|
||||
def test_custom_base_url_from_env(self):
|
||||
custom_api = PlusAPI("test_key")
|
||||
self.assertEqual(
|
||||
custom_api.base_url,
|
||||
"https://custom-url-from-env.com",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("httpx.AsyncClient")
|
||||
async def test_get_agent(mock_async_client_class):
|
||||
api = PlusAPI("test_api_key")
|
||||
mock_response = MagicMock()
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
mock_async_client_class.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
response = await api.get_agent("test_agent_handle")
|
||||
|
||||
mock_client_instance.get.assert_called_once_with(
|
||||
f"{api.base_url}/crewai_plus/api/v1/agents/test_agent_handle",
|
||||
headers=api.headers,
|
||||
)
|
||||
assert response == mock_response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("httpx.AsyncClient")
|
||||
@patch("crewai_cli.plus_api.Settings")
|
||||
async def test_get_agent_with_org_uuid(mock_settings_class, mock_async_client_class):
|
||||
org_uuid = "test-org-uuid"
|
||||
mock_settings = MagicMock()
|
||||
mock_settings.org_uuid = org_uuid
|
||||
mock_settings.enterprise_base_url = os.getenv("CREWAI_PLUS_URL")
|
||||
mock_settings_class.return_value = mock_settings
|
||||
|
||||
api = PlusAPI("test_api_key")
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
mock_async_client_class.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
response = await api.get_agent("test_agent_handle")
|
||||
|
||||
mock_client_instance.get.assert_called_once_with(
|
||||
f"{api.base_url}/crewai_plus/api/v1/agents/test_agent_handle",
|
||||
headers=api.headers,
|
||||
)
|
||||
assert "X-Crewai-Organization-Id" in api.headers
|
||||
assert api.headers["X-Crewai-Organization-Id"] == org_uuid
|
||||
assert response == mock_response
|
||||
90
lib/cli/tests/test_settings_command.py
Normal file
90
lib/cli/tests/test_settings_command.py
Normal file
@@ -0,0 +1,90 @@
|
||||
import tempfile
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock, call
|
||||
|
||||
from crewai_cli.settings.main import SettingsCommand
|
||||
from crewai_cli.config import (
|
||||
Settings,
|
||||
USER_SETTINGS_KEYS,
|
||||
CLI_SETTINGS_KEYS,
|
||||
DEFAULT_CLI_SETTINGS,
|
||||
HIDDEN_SETTINGS_KEYS,
|
||||
READONLY_SETTINGS_KEYS,
|
||||
)
|
||||
import shutil
|
||||
|
||||
|
||||
class TestSettingsCommand(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.test_dir = Path(tempfile.mkdtemp())
|
||||
self.config_path = self.test_dir / "settings.json"
|
||||
self.settings = Settings(config_path=self.config_path)
|
||||
self.settings_command = SettingsCommand(
|
||||
settings_kwargs={"config_path": self.config_path}
|
||||
)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.test_dir)
|
||||
|
||||
@patch("crewai_cli.settings.main.console")
|
||||
@patch("crewai_cli.settings.main.Table")
|
||||
def test_list_settings(self, mock_table_class, mock_console):
|
||||
mock_table_instance = MagicMock()
|
||||
mock_table_class.return_value = mock_table_instance
|
||||
|
||||
self.settings_command.list()
|
||||
|
||||
# Tests that the table is created skipping hidden settings
|
||||
mock_table_instance.add_row.assert_has_calls(
|
||||
[
|
||||
call(
|
||||
field_name,
|
||||
getattr(self.settings, field_name) or "Not set",
|
||||
field_info.description,
|
||||
)
|
||||
for field_name, field_info in Settings.model_fields.items()
|
||||
if field_name not in HIDDEN_SETTINGS_KEYS
|
||||
]
|
||||
)
|
||||
|
||||
# Tests that the table is printed
|
||||
mock_console.print.assert_called_once_with(mock_table_instance)
|
||||
|
||||
def test_set_valid_keys(self):
|
||||
valid_keys = Settings.model_fields.keys() - (
|
||||
READONLY_SETTINGS_KEYS + HIDDEN_SETTINGS_KEYS
|
||||
)
|
||||
for key in valid_keys:
|
||||
test_value = f"some_value_for_{key}"
|
||||
self.settings_command.set(key, test_value)
|
||||
self.assertEqual(getattr(self.settings_command.settings, key), test_value)
|
||||
|
||||
def test_set_invalid_key(self):
|
||||
with self.assertRaises(SystemExit):
|
||||
self.settings_command.set("invalid_key", "value")
|
||||
|
||||
def test_set_readonly_keys(self):
|
||||
for key in READONLY_SETTINGS_KEYS:
|
||||
with self.assertRaises(SystemExit):
|
||||
self.settings_command.set(key, "some_readonly_key_value")
|
||||
|
||||
def test_set_hidden_keys(self):
|
||||
for key in HIDDEN_SETTINGS_KEYS:
|
||||
with self.assertRaises(SystemExit):
|
||||
self.settings_command.set(key, "some_hidden_key_value")
|
||||
|
||||
def test_reset_all_settings(self):
|
||||
for key in USER_SETTINGS_KEYS + CLI_SETTINGS_KEYS:
|
||||
setattr(self.settings_command.settings, key, f"custom_value_for_{key}")
|
||||
self.settings_command.settings.dump()
|
||||
|
||||
self.settings_command.reset_all_settings()
|
||||
|
||||
for key in USER_SETTINGS_KEYS:
|
||||
self.assertEqual(getattr(self.settings_command.settings, key), None)
|
||||
|
||||
for key in CLI_SETTINGS_KEYS:
|
||||
self.assertEqual(
|
||||
getattr(self.settings_command.settings, key), DEFAULT_CLI_SETTINGS.get(key)
|
||||
)
|
||||
294
lib/cli/tests/test_token_manager.py
Normal file
294
lib/cli/tests/test_token_manager.py
Normal file
@@ -0,0 +1,294 @@
|
||||
"""Tests for TokenManager with atomic file operations."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from crewai_cli.shared.token_manager import TokenManager
|
||||
|
||||
|
||||
class TestTokenManager(unittest.TestCase):
|
||||
"""Test cases for TokenManager."""
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def setUp(self, mock_get_key: unittest.mock.MagicMock) -> None:
|
||||
"""Set up test fixtures."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
self.token_manager = TokenManager()
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_get_or_create_key_existing(
|
||||
self,
|
||||
mock_get_or_create: unittest.mock.MagicMock,
|
||||
mock_read: unittest.mock.MagicMock,
|
||||
) -> None:
|
||||
"""Test that existing key is returned when present."""
|
||||
mock_key = Fernet.generate_key()
|
||||
mock_get_or_create.return_value = mock_key
|
||||
|
||||
token_manager = TokenManager()
|
||||
result = token_manager.key
|
||||
|
||||
self.assertEqual(result, mock_key)
|
||||
|
||||
def test_get_or_create_key_new(self) -> None:
|
||||
"""Test that new key is created when none exists."""
|
||||
mock_key = Fernet.generate_key()
|
||||
|
||||
with (
|
||||
patch.object(self.token_manager, "_read_secure_file", return_value=None) as mock_read,
|
||||
patch.object(self.token_manager, "_atomic_create_secure_file", return_value=True) as mock_atomic_create,
|
||||
patch("crewai_cli.shared.token_manager.Fernet.generate_key", return_value=mock_key) as mock_generate,
|
||||
):
|
||||
result = self.token_manager._get_or_create_key()
|
||||
|
||||
self.assertEqual(result, mock_key)
|
||||
mock_read.assert_called_with("secret.key")
|
||||
mock_generate.assert_called_once()
|
||||
mock_atomic_create.assert_called_once_with("secret.key", mock_key)
|
||||
|
||||
def test_get_or_create_key_race_condition(self) -> None:
|
||||
"""Test that another process's key is used when atomic create fails."""
|
||||
our_key = Fernet.generate_key()
|
||||
their_key = Fernet.generate_key()
|
||||
|
||||
with (
|
||||
patch.object(self.token_manager, "_read_secure_file", side_effect=[None, their_key]) as mock_read,
|
||||
patch.object(self.token_manager, "_atomic_create_secure_file", return_value=False) as mock_atomic_create,
|
||||
patch("crewai_cli.shared.token_manager.Fernet.generate_key", return_value=our_key),
|
||||
):
|
||||
result = self.token_manager._get_or_create_key()
|
||||
|
||||
self.assertEqual(result, their_key)
|
||||
self.assertEqual(mock_read.call_count, 2)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._atomic_write_secure_file")
|
||||
def test_save_tokens(
|
||||
self, mock_write: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test saving tokens encrypts and writes atomically."""
|
||||
access_token = "test_token"
|
||||
expires_at = int((datetime.now() + timedelta(seconds=3600)).timestamp())
|
||||
|
||||
self.token_manager.save_tokens(access_token, expires_at)
|
||||
|
||||
mock_write.assert_called_once()
|
||||
args = mock_write.call_args[0]
|
||||
self.assertEqual(args[0], "tokens.enc")
|
||||
decrypted_data = self.token_manager.fernet.decrypt(args[1])
|
||||
data = json.loads(decrypted_data)
|
||||
self.assertEqual(data["access_token"], access_token)
|
||||
expiration = datetime.fromisoformat(data["expiration"])
|
||||
self.assertEqual(expiration, datetime.fromtimestamp(expires_at))
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
|
||||
def test_get_token_valid(
|
||||
self, mock_read: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test getting a valid non-expired token."""
|
||||
access_token = "test_token"
|
||||
expiration = (datetime.now() + timedelta(hours=1)).isoformat()
|
||||
data = {"access_token": access_token, "expiration": expiration}
|
||||
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
|
||||
mock_read.return_value = encrypted_data
|
||||
|
||||
result = self.token_manager.get_token()
|
||||
|
||||
self.assertEqual(result, access_token)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
|
||||
def test_get_token_expired(
|
||||
self, mock_read: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test that expired token returns None."""
|
||||
access_token = "test_token"
|
||||
expiration = (datetime.now() - timedelta(hours=1)).isoformat()
|
||||
data = {"access_token": access_token, "expiration": expiration}
|
||||
encrypted_data = self.token_manager.fernet.encrypt(json.dumps(data).encode())
|
||||
mock_read.return_value = encrypted_data
|
||||
|
||||
result = self.token_manager.get_token()
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._read_secure_file")
|
||||
def test_get_token_not_found(
|
||||
self, mock_read: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test that missing token file returns None."""
|
||||
mock_read.return_value = None
|
||||
|
||||
result = self.token_manager.get_token()
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._delete_secure_file")
|
||||
def test_clear_tokens(
|
||||
self, mock_delete: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test clearing tokens deletes the token file."""
|
||||
self.token_manager.clear_tokens()
|
||||
|
||||
mock_delete.assert_called_once_with("tokens.enc")
|
||||
|
||||
|
||||
class TestAtomicFileOperations(unittest.TestCase):
|
||||
"""Test atomic file operations directly."""
|
||||
|
||||
def setUp(self) -> None:
|
||||
"""Set up test fixtures with temp directory."""
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.original_get_path = TokenManager._get_secure_storage_path
|
||||
|
||||
# Patch to use temp directory
|
||||
def mock_get_path() -> Path:
|
||||
return Path(self.temp_dir)
|
||||
|
||||
TokenManager._get_secure_storage_path = staticmethod(mock_get_path)
|
||||
|
||||
def tearDown(self) -> None:
|
||||
"""Clean up temp directory."""
|
||||
TokenManager._get_secure_storage_path = staticmethod(self.original_get_path)
|
||||
import shutil
|
||||
shutil.rmtree(self.temp_dir, ignore_errors=True)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_atomic_create_new_file(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test atomic create succeeds for new file."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
result = tm._atomic_create_secure_file("test.txt", b"content")
|
||||
|
||||
self.assertTrue(result)
|
||||
file_path = Path(self.temp_dir) / "test.txt"
|
||||
self.assertTrue(file_path.exists())
|
||||
self.assertEqual(file_path.read_bytes(), b"content")
|
||||
self.assertEqual(file_path.stat().st_mode & 0o777, 0o600)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_atomic_create_existing_file(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test atomic create fails for existing file."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
# Create file first
|
||||
file_path = Path(self.temp_dir) / "test.txt"
|
||||
file_path.write_bytes(b"original")
|
||||
|
||||
result = tm._atomic_create_secure_file("test.txt", b"new content")
|
||||
|
||||
self.assertFalse(result)
|
||||
self.assertEqual(file_path.read_bytes(), b"original")
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_atomic_write_new_file(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test atomic write creates new file."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
tm._atomic_write_secure_file("test.txt", b"content")
|
||||
|
||||
file_path = Path(self.temp_dir) / "test.txt"
|
||||
self.assertTrue(file_path.exists())
|
||||
self.assertEqual(file_path.read_bytes(), b"content")
|
||||
self.assertEqual(file_path.stat().st_mode & 0o777, 0o600)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_atomic_write_overwrites(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test atomic write overwrites existing file."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
file_path = Path(self.temp_dir) / "test.txt"
|
||||
file_path.write_bytes(b"original")
|
||||
|
||||
tm._atomic_write_secure_file("test.txt", b"new content")
|
||||
|
||||
self.assertEqual(file_path.read_bytes(), b"new content")
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_atomic_write_no_temp_file_on_success(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test that temp file is cleaned up after successful write."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
tm._atomic_write_secure_file("test.txt", b"content")
|
||||
|
||||
# Check no temp files remain
|
||||
temp_files = list(Path(self.temp_dir).glob(".test.txt.*"))
|
||||
self.assertEqual(len(temp_files), 0)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_read_secure_file_exists(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test reading existing file."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
file_path = Path(self.temp_dir) / "test.txt"
|
||||
file_path.write_bytes(b"content")
|
||||
|
||||
result = tm._read_secure_file("test.txt")
|
||||
|
||||
self.assertEqual(result, b"content")
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_read_secure_file_not_exists(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test reading non-existent file returns None."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
result = tm._read_secure_file("nonexistent.txt")
|
||||
|
||||
self.assertIsNone(result)
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_delete_secure_file_exists(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test deleting existing file."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
file_path = Path(self.temp_dir) / "test.txt"
|
||||
file_path.write_bytes(b"content")
|
||||
|
||||
tm._delete_secure_file("test.txt")
|
||||
|
||||
self.assertFalse(file_path.exists())
|
||||
|
||||
@patch("crewai_cli.shared.token_manager.TokenManager._get_or_create_key")
|
||||
def test_delete_secure_file_not_exists(
|
||||
self, mock_get_key: unittest.mock.MagicMock
|
||||
) -> None:
|
||||
"""Test deleting non-existent file doesn't raise."""
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
# Should not raise
|
||||
tm._delete_secure_file("nonexistent.txt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
89
lib/cli/tests/test_train_crew.py
Normal file
89
lib/cli/tests/test_train_crew.py
Normal file
@@ -0,0 +1,89 @@
|
||||
import subprocess
|
||||
from unittest import mock
|
||||
|
||||
from crewai_cli.train_crew import train_crew
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.train_crew.subprocess.run")
|
||||
def test_train_crew_positive_iterations(mock_subprocess_run):
|
||||
n_iterations = 5
|
||||
mock_subprocess_run.return_value = subprocess.CompletedProcess(
|
||||
args=["uv", "run", "train", str(n_iterations)],
|
||||
returncode=0,
|
||||
stdout="Success",
|
||||
stderr="",
|
||||
)
|
||||
|
||||
train_crew(n_iterations, "trained_agents_data.pkl")
|
||||
|
||||
mock_subprocess_run.assert_called_once_with(
|
||||
["uv", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.train_crew.click")
|
||||
def test_train_crew_zero_iterations(click):
|
||||
train_crew(0, "trained_agents_data.pkl")
|
||||
click.echo.assert_called_once_with(
|
||||
"An unexpected error occurred: The number of iterations must be a positive integer.",
|
||||
err=True,
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.train_crew.click")
|
||||
def test_train_crew_negative_iterations(click):
|
||||
train_crew(-2, "trained_agents_data.pkl")
|
||||
click.echo.assert_called_once_with(
|
||||
"An unexpected error occurred: The number of iterations must be a positive integer.",
|
||||
err=True,
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.train_crew.click")
|
||||
@mock.patch("crewai_cli.train_crew.subprocess.run")
|
||||
def test_train_crew_called_process_error(mock_subprocess_run, click):
|
||||
n_iterations = 5
|
||||
mock_subprocess_run.side_effect = subprocess.CalledProcessError(
|
||||
returncode=1,
|
||||
cmd=["uv", "run", "train", str(n_iterations)],
|
||||
output="Error",
|
||||
stderr="Some error occurred",
|
||||
)
|
||||
train_crew(n_iterations, "trained_agents_data.pkl")
|
||||
|
||||
mock_subprocess_run.assert_called_once_with(
|
||||
["uv", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
click.echo.assert_has_calls(
|
||||
[
|
||||
mock.call.echo(
|
||||
"An error occurred while training the crew: Command '['uv', 'run', 'train', '5']' returned non-zero exit status 1.",
|
||||
err=True,
|
||||
),
|
||||
mock.call.echo("Error", err=True),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@mock.patch("crewai_cli.train_crew.click")
|
||||
@mock.patch("crewai_cli.train_crew.subprocess.run")
|
||||
def test_train_crew_unexpected_exception(mock_subprocess_run, click):
|
||||
n_iterations = 5
|
||||
mock_subprocess_run.side_effect = Exception("Unexpected error")
|
||||
train_crew(n_iterations, "trained_agents_data.pkl")
|
||||
|
||||
mock_subprocess_run.assert_called_once_with(
|
||||
["uv", "run", "train", str(n_iterations), "trained_agents_data.pkl"],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
click.echo.assert_called_once_with(
|
||||
"An unexpected error occurred: Unexpected error", err=True
|
||||
)
|
||||
146
lib/cli/tests/test_utils.py
Normal file
146
lib/cli/tests/test_utils.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import os
|
||||
import shutil
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from crewai_cli import utils
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_tree():
|
||||
root_dir = tempfile.mkdtemp()
|
||||
|
||||
create_file(os.path.join(root_dir, "file1.txt"), "Hello, world!")
|
||||
create_file(os.path.join(root_dir, "file2.txt"), "Another file")
|
||||
os.mkdir(os.path.join(root_dir, "empty_dir"))
|
||||
nested_dir = os.path.join(root_dir, "nested_dir")
|
||||
os.mkdir(nested_dir)
|
||||
create_file(os.path.join(nested_dir, "nested_file.txt"), "Nested content")
|
||||
|
||||
yield root_dir
|
||||
|
||||
shutil.rmtree(root_dir)
|
||||
|
||||
|
||||
def create_file(path, content):
|
||||
with open(path, "w") as f:
|
||||
f.write(content)
|
||||
|
||||
|
||||
def test_tree_find_and_replace_file_content(temp_tree):
|
||||
utils.tree_find_and_replace(temp_tree, "world", "universe")
|
||||
with open(os.path.join(temp_tree, "file1.txt"), "r") as f:
|
||||
assert f.read() == "Hello, universe!"
|
||||
|
||||
|
||||
def test_tree_find_and_replace_file_name(temp_tree):
|
||||
old_path = os.path.join(temp_tree, "file2.txt")
|
||||
new_path = os.path.join(temp_tree, "file2_renamed.txt")
|
||||
os.rename(old_path, new_path)
|
||||
utils.tree_find_and_replace(temp_tree, "renamed", "modified")
|
||||
assert os.path.exists(os.path.join(temp_tree, "file2_modified.txt"))
|
||||
assert not os.path.exists(new_path)
|
||||
|
||||
|
||||
def test_tree_find_and_replace_directory_name(temp_tree):
|
||||
utils.tree_find_and_replace(temp_tree, "empty", "renamed")
|
||||
assert os.path.exists(os.path.join(temp_tree, "renamed_dir"))
|
||||
assert not os.path.exists(os.path.join(temp_tree, "empty_dir"))
|
||||
|
||||
|
||||
def test_tree_find_and_replace_nested_content(temp_tree):
|
||||
utils.tree_find_and_replace(temp_tree, "Nested", "Updated")
|
||||
with open(os.path.join(temp_tree, "nested_dir", "nested_file.txt"), "r") as f:
|
||||
assert f.read() == "Updated content"
|
||||
|
||||
|
||||
def test_tree_find_and_replace_no_matches(temp_tree):
|
||||
utils.tree_find_and_replace(temp_tree, "nonexistent", "replacement")
|
||||
assert set(os.listdir(temp_tree)) == {
|
||||
"file1.txt",
|
||||
"file2.txt",
|
||||
"empty_dir",
|
||||
"nested_dir",
|
||||
}
|
||||
|
||||
|
||||
def test_tree_copy_full_structure(temp_tree):
|
||||
dest_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
utils.tree_copy(temp_tree, dest_dir)
|
||||
assert set(os.listdir(dest_dir)) == set(os.listdir(temp_tree))
|
||||
assert os.path.isfile(os.path.join(dest_dir, "file1.txt"))
|
||||
assert os.path.isfile(os.path.join(dest_dir, "file2.txt"))
|
||||
assert os.path.isdir(os.path.join(dest_dir, "empty_dir"))
|
||||
assert os.path.isdir(os.path.join(dest_dir, "nested_dir"))
|
||||
assert os.path.isfile(os.path.join(dest_dir, "nested_dir", "nested_file.txt"))
|
||||
finally:
|
||||
shutil.rmtree(dest_dir)
|
||||
|
||||
|
||||
def test_tree_copy_preserve_content(temp_tree):
|
||||
dest_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
utils.tree_copy(temp_tree, dest_dir)
|
||||
with open(os.path.join(dest_dir, "file1.txt"), "r") as f:
|
||||
assert f.read() == "Hello, world!"
|
||||
with open(os.path.join(dest_dir, "nested_dir", "nested_file.txt"), "r") as f:
|
||||
assert f.read() == "Nested content"
|
||||
finally:
|
||||
shutil.rmtree(dest_dir)
|
||||
|
||||
|
||||
def test_tree_copy_to_existing_directory(temp_tree):
|
||||
dest_dir = tempfile.mkdtemp()
|
||||
try:
|
||||
create_file(os.path.join(dest_dir, "existing_file.txt"), "I was here first")
|
||||
utils.tree_copy(temp_tree, dest_dir)
|
||||
assert os.path.isfile(os.path.join(dest_dir, "existing_file.txt"))
|
||||
assert os.path.isfile(os.path.join(dest_dir, "file1.txt"))
|
||||
finally:
|
||||
shutil.rmtree(dest_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_project_dir():
|
||||
"""Create a temporary directory for testing tool extraction."""
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
yield Path(temp_dir)
|
||||
|
||||
|
||||
def create_init_file(directory, content):
|
||||
return create_file(directory / "__init__.py", content)
|
||||
|
||||
|
||||
def test_extract_available_exports_empty_project(temp_project_dir, capsys):
|
||||
with pytest.raises(SystemExit):
|
||||
utils.extract_available_exports(dir_path=temp_project_dir)
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert "No valid tools were exposed in your __init__.py file" in captured.out
|
||||
|
||||
|
||||
def test_extract_available_exports_no_init_file(temp_project_dir, capsys):
|
||||
(temp_project_dir / "some_file.py").write_text("print('hello')")
|
||||
with pytest.raises(SystemExit):
|
||||
utils.extract_available_exports(dir_path=temp_project_dir)
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert "No valid tools were exposed in your __init__.py file" in captured.out
|
||||
|
||||
|
||||
def test_extract_available_exports_empty_init_file(temp_project_dir, capsys):
|
||||
create_init_file(temp_project_dir, "")
|
||||
with pytest.raises(SystemExit):
|
||||
utils.extract_available_exports(dir_path=temp_project_dir)
|
||||
captured = capsys.readouterr()
|
||||
|
||||
assert "Warning: No __all__ defined in" in captured.out
|
||||
|
||||
|
||||
# Tests for extract_available_exports with crewai.tools (BaseTool, @tool)
|
||||
# remain in lib/crewai/tests/cli/test_utils.py as they require the crewai core package.
|
||||
|
||||
# Tests for get_crews, get_flows, fetch_crews, is_valid_tool
|
||||
# remain in lib/crewai/tests/cli/test_utils.py as they require the crewai core package.
|
||||
372
lib/cli/tests/test_version.py
Normal file
372
lib/cli/tests/test_version.py
Normal file
@@ -0,0 +1,372 @@
|
||||
"""Test for version management."""
|
||||
|
||||
import json
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai_cli.version import get_crewai_version as _get_ver
|
||||
from crewai_cli.version import (
|
||||
_find_latest_non_yanked_version,
|
||||
_get_cache_file,
|
||||
_is_cache_valid,
|
||||
_is_version_yanked,
|
||||
get_crewai_version,
|
||||
get_latest_version_from_pypi,
|
||||
is_current_version_yanked,
|
||||
is_newer_version_available,
|
||||
)
|
||||
|
||||
|
||||
def test_dynamic_versioning_consistency() -> None:
|
||||
"""Test that dynamic versioning provides consistent version across all access methods."""
|
||||
cli_version = get_crewai_version()
|
||||
package_version = _get_ver()
|
||||
|
||||
assert cli_version == package_version
|
||||
|
||||
assert package_version is not None
|
||||
assert len(package_version.strip()) > 0
|
||||
|
||||
|
||||
class TestVersionChecking:
|
||||
"""Test version checking utilities."""
|
||||
|
||||
def test_get_crewai_version(self) -> None:
|
||||
"""Test getting current crewai version."""
|
||||
version = get_crewai_version()
|
||||
assert isinstance(version, str)
|
||||
assert len(version) > 0
|
||||
|
||||
def test_get_cache_file(self) -> None:
|
||||
"""Test cache file path generation."""
|
||||
cache_file = _get_cache_file()
|
||||
assert isinstance(cache_file, Path)
|
||||
assert cache_file.name == "version_cache.json"
|
||||
|
||||
def test_is_cache_valid_with_fresh_cache(self) -> None:
|
||||
"""Test cache validation with fresh cache."""
|
||||
cache_data = {"timestamp": datetime.now().isoformat(), "version": "1.0.0"}
|
||||
assert _is_cache_valid(cache_data) is True
|
||||
|
||||
def test_is_cache_valid_with_stale_cache(self) -> None:
|
||||
"""Test cache validation with stale cache."""
|
||||
old_time = datetime.now() - timedelta(hours=25)
|
||||
cache_data = {"timestamp": old_time.isoformat(), "version": "1.0.0"}
|
||||
assert _is_cache_valid(cache_data) is False
|
||||
|
||||
def test_is_cache_valid_with_missing_timestamp(self) -> None:
|
||||
"""Test cache validation with missing timestamp."""
|
||||
cache_data = {"version": "1.0.0"}
|
||||
assert _is_cache_valid(cache_data) is False
|
||||
|
||||
@patch("crewai_cli.version.Path.exists")
|
||||
@patch("crewai_cli.version.request.urlopen")
|
||||
def test_get_latest_version_from_pypi_success(
|
||||
self, mock_urlopen: MagicMock, mock_exists: MagicMock
|
||||
) -> None:
|
||||
"""Test successful PyPI version fetch uses releases data."""
|
||||
mock_exists.return_value = False
|
||||
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": False}],
|
||||
"2.0.0": [{"yanked": False}],
|
||||
"2.1.0": [{"yanked": True, "yanked_reason": "bad release"}],
|
||||
}
|
||||
mock_response = MagicMock()
|
||||
mock_response.read.return_value = json.dumps(
|
||||
{"info": {"version": "2.1.0"}, "releases": releases}
|
||||
).encode()
|
||||
mock_urlopen.return_value.__enter__.return_value = mock_response
|
||||
|
||||
version = get_latest_version_from_pypi()
|
||||
assert version == "2.0.0"
|
||||
|
||||
@patch("crewai_cli.version.Path.exists")
|
||||
@patch("crewai_cli.version.request.urlopen")
|
||||
def test_get_latest_version_from_pypi_failure(
|
||||
self, mock_urlopen: MagicMock, mock_exists: MagicMock
|
||||
) -> None:
|
||||
"""Test PyPI version fetch failure."""
|
||||
from urllib.error import URLError
|
||||
|
||||
mock_exists.return_value = False
|
||||
|
||||
mock_urlopen.side_effect = URLError("Network error")
|
||||
|
||||
version = get_latest_version_from_pypi()
|
||||
assert version is None
|
||||
|
||||
@patch("crewai_cli.version.get_crewai_version")
|
||||
@patch("crewai_cli.version.get_latest_version_from_pypi")
|
||||
def test_is_newer_version_available_true(
|
||||
self, mock_latest: MagicMock, mock_current: MagicMock
|
||||
) -> None:
|
||||
"""Test when newer version is available."""
|
||||
mock_current.return_value = "1.0.0"
|
||||
mock_latest.return_value = "2.0.0"
|
||||
|
||||
is_newer, current, latest = is_newer_version_available()
|
||||
assert is_newer is True
|
||||
assert current == "1.0.0"
|
||||
assert latest == "2.0.0"
|
||||
|
||||
@patch("crewai_cli.version.get_crewai_version")
|
||||
@patch("crewai_cli.version.get_latest_version_from_pypi")
|
||||
def test_is_newer_version_available_false(
|
||||
self, mock_latest: MagicMock, mock_current: MagicMock
|
||||
) -> None:
|
||||
"""Test when no newer version is available."""
|
||||
mock_current.return_value = "2.0.0"
|
||||
mock_latest.return_value = "2.0.0"
|
||||
|
||||
is_newer, current, latest = is_newer_version_available()
|
||||
assert is_newer is False
|
||||
assert current == "2.0.0"
|
||||
assert latest == "2.0.0"
|
||||
|
||||
@patch("crewai_cli.version.get_crewai_version")
|
||||
@patch("crewai_cli.version.get_latest_version_from_pypi")
|
||||
def test_is_newer_version_available_with_none_latest(
|
||||
self, mock_latest: MagicMock, mock_current: MagicMock
|
||||
) -> None:
|
||||
"""Test when PyPI fetch fails."""
|
||||
mock_current.return_value = "1.0.0"
|
||||
mock_latest.return_value = None
|
||||
|
||||
is_newer, current, latest = is_newer_version_available()
|
||||
assert is_newer is False
|
||||
assert current == "1.0.0"
|
||||
assert latest is None
|
||||
|
||||
|
||||
class TestFindLatestNonYankedVersion:
|
||||
"""Test _find_latest_non_yanked_version helper."""
|
||||
|
||||
def test_skips_yanked_versions(self) -> None:
|
||||
"""Test that yanked versions are skipped."""
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": False}],
|
||||
"2.0.0": [{"yanked": True}],
|
||||
}
|
||||
assert _find_latest_non_yanked_version(releases) == "1.0.0"
|
||||
|
||||
def test_returns_highest_non_yanked(self) -> None:
|
||||
"""Test that the highest non-yanked version is returned."""
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": False}],
|
||||
"1.5.0": [{"yanked": False}],
|
||||
"2.0.0": [{"yanked": True}],
|
||||
}
|
||||
assert _find_latest_non_yanked_version(releases) == "1.5.0"
|
||||
|
||||
def test_returns_none_when_all_yanked(self) -> None:
|
||||
"""Test that None is returned when all versions are yanked."""
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": True}],
|
||||
"2.0.0": [{"yanked": True}],
|
||||
}
|
||||
assert _find_latest_non_yanked_version(releases) is None
|
||||
|
||||
def test_skips_prerelease_versions(self) -> None:
|
||||
"""Test that pre-release versions are skipped."""
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": False}],
|
||||
"2.0.0a1": [{"yanked": False}],
|
||||
"2.0.0rc1": [{"yanked": False}],
|
||||
}
|
||||
assert _find_latest_non_yanked_version(releases) == "1.0.0"
|
||||
|
||||
def test_skips_versions_with_empty_files(self) -> None:
|
||||
"""Test that versions with no files are skipped."""
|
||||
releases: dict[str, list[dict[str, bool]]] = {
|
||||
"1.0.0": [{"yanked": False}],
|
||||
"2.0.0": [],
|
||||
}
|
||||
assert _find_latest_non_yanked_version(releases) == "1.0.0"
|
||||
|
||||
def test_handles_invalid_version_strings(self) -> None:
|
||||
"""Test that invalid version strings are skipped."""
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": False}],
|
||||
"not-a-version": [{"yanked": False}],
|
||||
}
|
||||
assert _find_latest_non_yanked_version(releases) == "1.0.0"
|
||||
|
||||
def test_partially_yanked_files_not_considered_yanked(self) -> None:
|
||||
"""Test that a version with some non-yanked files is not yanked."""
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": False}],
|
||||
"2.0.0": [{"yanked": True}, {"yanked": False}],
|
||||
}
|
||||
assert _find_latest_non_yanked_version(releases) == "2.0.0"
|
||||
|
||||
|
||||
class TestIsVersionYanked:
|
||||
"""Test _is_version_yanked helper."""
|
||||
|
||||
def test_non_yanked_version(self) -> None:
|
||||
"""Test a non-yanked version returns False."""
|
||||
releases = {"1.0.0": [{"yanked": False}]}
|
||||
is_yanked, reason = _is_version_yanked("1.0.0", releases)
|
||||
assert is_yanked is False
|
||||
assert reason == ""
|
||||
|
||||
def test_yanked_version_with_reason(self) -> None:
|
||||
"""Test a yanked version returns True with reason."""
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": True, "yanked_reason": "critical bug"}],
|
||||
}
|
||||
is_yanked, reason = _is_version_yanked("1.0.0", releases)
|
||||
assert is_yanked is True
|
||||
assert reason == "critical bug"
|
||||
|
||||
def test_yanked_version_without_reason(self) -> None:
|
||||
"""Test a yanked version returns True with empty reason."""
|
||||
releases = {"1.0.0": [{"yanked": True}]}
|
||||
is_yanked, reason = _is_version_yanked("1.0.0", releases)
|
||||
assert is_yanked is True
|
||||
assert reason == ""
|
||||
|
||||
def test_unknown_version(self) -> None:
|
||||
"""Test an unknown version returns False."""
|
||||
releases = {"1.0.0": [{"yanked": False}]}
|
||||
is_yanked, reason = _is_version_yanked("9.9.9", releases)
|
||||
assert is_yanked is False
|
||||
assert reason == ""
|
||||
|
||||
def test_partially_yanked_files(self) -> None:
|
||||
"""Test a version with mixed yanked/non-yanked files is not yanked."""
|
||||
releases = {
|
||||
"1.0.0": [{"yanked": True}, {"yanked": False}],
|
||||
}
|
||||
is_yanked, reason = _is_version_yanked("1.0.0", releases)
|
||||
assert is_yanked is False
|
||||
assert reason == ""
|
||||
|
||||
def test_multiple_yanked_files_picks_first_reason(self) -> None:
|
||||
"""Test that the first available reason is returned."""
|
||||
releases = {
|
||||
"1.0.0": [
|
||||
{"yanked": True, "yanked_reason": ""},
|
||||
{"yanked": True, "yanked_reason": "second reason"},
|
||||
],
|
||||
}
|
||||
is_yanked, reason = _is_version_yanked("1.0.0", releases)
|
||||
assert is_yanked is True
|
||||
assert reason == "second reason"
|
||||
|
||||
|
||||
class TestIsCurrentVersionYanked:
|
||||
"""Test is_current_version_yanked public function."""
|
||||
|
||||
@patch("crewai_cli.version.get_crewai_version")
|
||||
@patch("crewai_cli.version._get_cache_file")
|
||||
def test_reads_from_valid_cache(
|
||||
self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test reading yanked status from a valid cache."""
|
||||
mock_version.return_value = "1.0.0"
|
||||
cache_file = tmp_path / "version_cache.json"
|
||||
cache_data = {
|
||||
"version": "2.0.0",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"current_version": "1.0.0",
|
||||
"current_version_yanked": True,
|
||||
"current_version_yanked_reason": "bad release",
|
||||
}
|
||||
cache_file.write_text(json.dumps(cache_data))
|
||||
mock_cache_file.return_value = cache_file
|
||||
|
||||
is_yanked, reason = is_current_version_yanked()
|
||||
assert is_yanked is True
|
||||
assert reason == "bad release"
|
||||
|
||||
@patch("crewai_cli.version.get_crewai_version")
|
||||
@patch("crewai_cli.version._get_cache_file")
|
||||
def test_not_yanked_from_cache(
|
||||
self, mock_cache_file: MagicMock, mock_version: MagicMock, tmp_path: Path
|
||||
) -> None:
|
||||
"""Test non-yanked status from a valid cache."""
|
||||
mock_version.return_value = "2.0.0"
|
||||
cache_file = tmp_path / "version_cache.json"
|
||||
cache_data = {
|
||||
"version": "2.0.0",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"current_version": "2.0.0",
|
||||
"current_version_yanked": False,
|
||||
"current_version_yanked_reason": "",
|
||||
}
|
||||
cache_file.write_text(json.dumps(cache_data))
|
||||
mock_cache_file.return_value = cache_file
|
||||
|
||||
is_yanked, reason = is_current_version_yanked()
|
||||
assert is_yanked is False
|
||||
assert reason == ""
|
||||
|
||||
@patch("crewai_cli.version.get_latest_version_from_pypi")
|
||||
@patch("crewai_cli.version.get_crewai_version")
|
||||
@patch("crewai_cli.version._get_cache_file")
|
||||
def test_triggers_fetch_on_stale_cache(
|
||||
self,
|
||||
mock_cache_file: MagicMock,
|
||||
mock_version: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test that a stale cache triggers a re-fetch."""
|
||||
mock_version.return_value = "1.0.0"
|
||||
cache_file = tmp_path / "version_cache.json"
|
||||
old_time = datetime.now() - timedelta(hours=25)
|
||||
cache_data = {
|
||||
"version": "2.0.0",
|
||||
"timestamp": old_time.isoformat(),
|
||||
"current_version": "1.0.0",
|
||||
"current_version_yanked": True,
|
||||
"current_version_yanked_reason": "old reason",
|
||||
}
|
||||
cache_file.write_text(json.dumps(cache_data))
|
||||
mock_cache_file.return_value = cache_file
|
||||
|
||||
fresh_cache = {
|
||||
"version": "2.0.0",
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"current_version": "1.0.0",
|
||||
"current_version_yanked": False,
|
||||
"current_version_yanked_reason": "",
|
||||
}
|
||||
|
||||
def write_fresh_cache() -> str:
|
||||
cache_file.write_text(json.dumps(fresh_cache))
|
||||
return "2.0.0"
|
||||
|
||||
mock_fetch.side_effect = lambda: write_fresh_cache()
|
||||
|
||||
is_yanked, reason = is_current_version_yanked()
|
||||
assert is_yanked is False
|
||||
mock_fetch.assert_called_once()
|
||||
|
||||
@patch("crewai_cli.version.get_latest_version_from_pypi")
|
||||
@patch("crewai_cli.version.get_crewai_version")
|
||||
@patch("crewai_cli.version._get_cache_file")
|
||||
def test_returns_false_on_fetch_failure(
|
||||
self,
|
||||
mock_cache_file: MagicMock,
|
||||
mock_version: MagicMock,
|
||||
mock_fetch: MagicMock,
|
||||
tmp_path: Path,
|
||||
) -> None:
|
||||
"""Test that fetch failure returns not yanked."""
|
||||
mock_version.return_value = "1.0.0"
|
||||
cache_file = tmp_path / "version_cache.json"
|
||||
mock_cache_file.return_value = cache_file
|
||||
mock_fetch.return_value = None
|
||||
|
||||
is_yanked, reason = is_current_version_yanked()
|
||||
assert is_yanked is False
|
||||
assert reason == ""
|
||||
|
||||
|
||||
|
||||
# TestConsoleFormatterVersionCheck tests remain in lib/crewai/tests/cli/test_version.py
|
||||
# as they depend on crewai.events.utils.console_formatter (core package).
|
||||
0
lib/cli/tests/tools/__init__.py
Normal file
0
lib/cli/tests/tools/__init__.py
Normal file
390
lib/cli/tests/tools/test_main.py
Normal file
390
lib/cli/tests/tools/test_main.py
Normal file
@@ -0,0 +1,390 @@
|
||||
import os
|
||||
import tempfile
|
||||
import unittest
|
||||
import unittest.mock
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from crewai_cli.shared.token_manager import TokenManager
|
||||
from crewai_cli.tools.main import ToolCommand
|
||||
from pytest import raises
|
||||
|
||||
|
||||
@contextmanager
|
||||
def in_temp_dir():
|
||||
original_dir = os.getcwd()
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
os.chdir(temp_dir)
|
||||
try:
|
||||
yield temp_dir
|
||||
finally:
|
||||
os.chdir(original_dir)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_command():
|
||||
# Create a temporary directory for each test to avoid token storage conflicts
|
||||
with tempfile.TemporaryDirectory() as temp_dir:
|
||||
# Mock the secure storage path to use the temp directory
|
||||
with patch.object(
|
||||
TokenManager, "_get_secure_storage_path", return_value=Path(temp_dir)
|
||||
):
|
||||
TokenManager().save_tokens(
|
||||
"test-token", (datetime.now() + timedelta(seconds=36000)).timestamp()
|
||||
)
|
||||
tool_command = ToolCommand()
|
||||
with patch.object(tool_command, "login"):
|
||||
yield tool_command
|
||||
|
||||
|
||||
@patch("crewai_cli.tools.main.subprocess.run")
|
||||
def test_create_success(mock_subprocess, capsys, tool_command):
|
||||
with in_temp_dir():
|
||||
tool_command.create("test-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "Creating custom tool test_tool..." in output
|
||||
|
||||
assert os.path.isdir("test_tool")
|
||||
assert os.path.isfile(os.path.join("test_tool", "README.md"))
|
||||
assert os.path.isfile(os.path.join("test_tool", "pyproject.toml"))
|
||||
assert os.path.isfile(
|
||||
os.path.join("test_tool", "src", "test_tool", "__init__.py")
|
||||
)
|
||||
assert os.path.isfile(os.path.join("test_tool", "src", "test_tool", "tool.py"))
|
||||
|
||||
with open(os.path.join("test_tool", "src", "test_tool", "tool.py"), "r") as f:
|
||||
content = f.read()
|
||||
assert "class TestTool" in content
|
||||
|
||||
mock_subprocess.assert_called_once_with(["git", "init"], check=True)
|
||||
|
||||
|
||||
@patch("crewai_cli.tools.main.subprocess.run")
|
||||
@patch("crewai_cli.plus_api.PlusAPI.get_tool")
|
||||
@patch("crewai_cli.tools.main.ToolCommand._print_current_organization")
|
||||
def test_install_success(
|
||||
mock_print_org, mock_get, mock_subprocess_run, capsys, tool_command
|
||||
):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 200
|
||||
mock_get_response.json.return_value = {
|
||||
"handle": "sample-tool",
|
||||
"repository": {"handle": "sample-repo", "url": "https://example.com/repo"},
|
||||
}
|
||||
mock_get.return_value = mock_get_response
|
||||
mock_subprocess_run.return_value = MagicMock(stderr=None)
|
||||
|
||||
tool_command.install("sample-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "Successfully installed sample-tool" in output
|
||||
|
||||
mock_get.assert_has_calls([mock.call("sample-tool"), mock.call().json()])
|
||||
mock_subprocess_run.assert_any_call(
|
||||
[
|
||||
"uv",
|
||||
"add",
|
||||
"--index",
|
||||
"sample-repo=https://example.com/repo",
|
||||
"sample-tool",
|
||||
],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
env=unittest.mock.ANY,
|
||||
)
|
||||
|
||||
# Verify _print_current_organization was called
|
||||
mock_print_org.assert_called_once()
|
||||
|
||||
|
||||
@patch("crewai_cli.tools.main.subprocess.run")
|
||||
@patch("crewai_cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_success_from_pypi(mock_get, mock_subprocess_run, capsys, tool_command):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 200
|
||||
mock_get_response.json.return_value = {
|
||||
"handle": "sample-tool",
|
||||
"repository": {"handle": "sample-repo", "url": "https://example.com/repo"},
|
||||
"source": "pypi",
|
||||
}
|
||||
mock_get.return_value = mock_get_response
|
||||
mock_subprocess_run.return_value = MagicMock(stderr=None)
|
||||
|
||||
tool_command.install("sample-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "Successfully installed sample-tool" in output
|
||||
|
||||
mock_get.assert_has_calls([mock.call("sample-tool"), mock.call().json()])
|
||||
mock_subprocess_run.assert_any_call(
|
||||
[
|
||||
"uv",
|
||||
"add",
|
||||
"sample-tool",
|
||||
],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True,
|
||||
env=unittest.mock.ANY,
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_tool_not_found(mock_get, capsys, tool_command):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 404
|
||||
mock_get.return_value = mock_get_response
|
||||
|
||||
with raises(SystemExit):
|
||||
tool_command.install("non-existent-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "No tool found with this name" in output
|
||||
|
||||
mock_get.assert_called_once_with("non-existent-tool")
|
||||
|
||||
|
||||
@patch("crewai_cli.plus_api.PlusAPI.get_tool")
|
||||
def test_install_api_error(mock_get, capsys, tool_command):
|
||||
mock_get_response = MagicMock()
|
||||
mock_get_response.status_code = 500
|
||||
mock_get.return_value = mock_get_response
|
||||
|
||||
with raises(SystemExit):
|
||||
tool_command.install("error-tool")
|
||||
output = capsys.readouterr().out
|
||||
assert "Failed to get tool details" in output
|
||||
|
||||
mock_get.assert_called_once_with("error-tool")
|
||||
|
||||
|
||||
@patch("crewai_cli.tools.main.git.Repository.fetch")
|
||||
@patch("crewai_cli.tools.main.git.Repository.is_synced", return_value=False)
|
||||
def test_publish_when_not_in_sync(mock_is_synced, mock_fetch, capsys, tool_command):
|
||||
with raises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "Local changes need to be resolved before publishing" in output
|
||||
|
||||
|
||||
@patch("crewai_cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai_cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch("crewai_cli.tools.main.get_project_description", return_value="A sample tool")
|
||||
@patch("crewai_cli.tools.main.subprocess.run")
|
||||
@patch("crewai_cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"])
|
||||
@patch(
|
||||
"crewai_cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai_cli.tools.main.git.Repository.fetch")
|
||||
@patch("crewai_cli.plus_api.PlusAPI.publish_tool")
|
||||
@patch("crewai_cli.tools.main.git.Repository.is_synced", return_value=False)
|
||||
@patch(
|
||||
"crewai_cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
@patch("crewai_cli.tools.main.ToolCommand._print_current_organization")
|
||||
def test_publish_when_not_in_sync_and_force(
|
||||
mock_print_org,
|
||||
mock_available_exports,
|
||||
mock_is_synced,
|
||||
mock_publish,
|
||||
mock_fetch,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
tool_command,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 200
|
||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command.publish(is_public=True, force=True)
|
||||
|
||||
mock_get_project_name.assert_called_with(require=True)
|
||||
mock_get_project_version.assert_called_with(require=True)
|
||||
mock_get_project_description.assert_called_with(require=False)
|
||||
mock_subprocess_run.assert_called_with(
|
||||
["uv", "build", "--sdist", "--out-dir", unittest.mock.ANY],
|
||||
check=True,
|
||||
capture_output=False,
|
||||
)
|
||||
mock_open.assert_called_with(unittest.mock.ANY, "rb")
|
||||
mock_publish.assert_called_with(
|
||||
handle="sample-tool",
|
||||
is_public=True,
|
||||
version="1.0.0",
|
||||
description="A sample tool",
|
||||
encoded_file=unittest.mock.ANY,
|
||||
available_exports=[{"name": "SampleTool"}],
|
||||
)
|
||||
mock_print_org.assert_called_once()
|
||||
|
||||
|
||||
@patch("crewai_cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai_cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch("crewai_cli.tools.main.get_project_description", return_value="A sample tool")
|
||||
@patch("crewai_cli.tools.main.subprocess.run")
|
||||
@patch("crewai_cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"])
|
||||
@patch(
|
||||
"crewai_cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai_cli.tools.main.git.Repository.fetch")
|
||||
@patch("crewai_cli.plus_api.PlusAPI.publish_tool")
|
||||
@patch("crewai_cli.tools.main.git.Repository.is_synced", return_value=True)
|
||||
@patch(
|
||||
"crewai_cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
def test_publish_success(
|
||||
mock_available_exports,
|
||||
mock_is_synced,
|
||||
mock_publish,
|
||||
mock_fetch,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
tool_command,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 200
|
||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command.publish(is_public=True)
|
||||
|
||||
mock_get_project_name.assert_called_with(require=True)
|
||||
mock_get_project_version.assert_called_with(require=True)
|
||||
mock_get_project_description.assert_called_with(require=False)
|
||||
mock_subprocess_run.assert_called_with(
|
||||
["uv", "build", "--sdist", "--out-dir", unittest.mock.ANY],
|
||||
check=True,
|
||||
capture_output=False,
|
||||
)
|
||||
mock_open.assert_called_with(unittest.mock.ANY, "rb")
|
||||
mock_publish.assert_called_with(
|
||||
handle="sample-tool",
|
||||
is_public=True,
|
||||
version="1.0.0",
|
||||
description="A sample tool",
|
||||
encoded_file=unittest.mock.ANY,
|
||||
available_exports=[{"name": "SampleTool"}],
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai_cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai_cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch("crewai_cli.tools.main.get_project_description", return_value="A sample tool")
|
||||
@patch("crewai_cli.tools.main.subprocess.run")
|
||||
@patch("crewai_cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"])
|
||||
@patch(
|
||||
"crewai_cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai_cli.plus_api.PlusAPI.publish_tool")
|
||||
@patch(
|
||||
"crewai_cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
def test_publish_failure(
|
||||
mock_available_exports,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
capsys,
|
||||
tool_command,
|
||||
):
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 422
|
||||
mock_publish_response.json.return_value = {"name": ["is already taken"]}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
with raises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
output = capsys.readouterr().out
|
||||
assert "Failed to complete operation" in output
|
||||
assert "Name is already taken" in output
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
|
||||
|
||||
@patch("crewai_cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai_cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch("crewai_cli.tools.main.get_project_description", return_value="A sample tool")
|
||||
@patch("crewai_cli.tools.main.subprocess.run")
|
||||
@patch("crewai_cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"])
|
||||
@patch(
|
||||
"crewai_cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai_cli.plus_api.PlusAPI.publish_tool")
|
||||
@patch(
|
||||
"crewai_cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
def test_publish_api_error(
|
||||
mock_available_exports,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
capsys,
|
||||
tool_command,
|
||||
):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 500
|
||||
mock_response.json.return_value = {"error": "Internal Server Error"}
|
||||
mock_response.is_success = False
|
||||
mock_publish.return_value = mock_response
|
||||
|
||||
with raises(SystemExit):
|
||||
tool_command.publish(is_public=True)
|
||||
output = capsys.readouterr().out
|
||||
assert "Request to Enterprise API failed" in output
|
||||
|
||||
mock_publish.assert_called_once()
|
||||
|
||||
|
||||
@patch("crewai_cli.tools.main.Settings")
|
||||
def test_print_current_organization_with_org(mock_settings, capsys, tool_command):
|
||||
mock_settings_instance = MagicMock()
|
||||
mock_settings_instance.org_uuid = "test-org-uuid"
|
||||
mock_settings_instance.org_name = "Test Organization"
|
||||
mock_settings.return_value = mock_settings_instance
|
||||
tool_command._print_current_organization()
|
||||
output = capsys.readouterr().out
|
||||
assert "Current organization: Test Organization (test-org-uuid)" in output
|
||||
|
||||
|
||||
@patch("crewai_cli.tools.main.Settings")
|
||||
def test_print_current_organization_without_org(mock_settings, capsys, tool_command):
|
||||
mock_settings_instance = MagicMock()
|
||||
mock_settings_instance.org_uuid = None
|
||||
mock_settings_instance.org_name = None
|
||||
mock_settings.return_value = mock_settings_instance
|
||||
tool_command._print_current_organization()
|
||||
output = capsys.readouterr().out
|
||||
assert "No organization currently set" in output
|
||||
assert "org switch <org_id>" in output
|
||||
170
lib/cli/tests/triggers/test_main.py
Normal file
170
lib/cli/tests/triggers/test_main.py
Normal file
@@ -0,0 +1,170 @@
|
||||
import json
|
||||
import subprocess
|
||||
import unittest
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import httpx
|
||||
from crewai_cli.triggers.main import TriggersCommand
|
||||
|
||||
|
||||
class TestTriggersCommand(unittest.TestCase):
|
||||
@patch("crewai_cli.command.get_auth_token")
|
||||
@patch("crewai_cli.command.PlusAPI")
|
||||
def setUp(self, mock_plus_api, mock_get_auth_token):
|
||||
self.mock_get_auth_token = mock_get_auth_token
|
||||
self.mock_plus_api = mock_plus_api
|
||||
|
||||
self.mock_get_auth_token.return_value = "test_token"
|
||||
|
||||
self.triggers_command = TriggersCommand()
|
||||
self.mock_client = self.triggers_command.plus_api_client
|
||||
|
||||
@patch("crewai_cli.triggers.main.console.print")
|
||||
def test_list_triggers_success(self, mock_console_print):
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 200
|
||||
mock_response.ok = True
|
||||
mock_response.json.return_value = {
|
||||
"apps": [
|
||||
{
|
||||
"name": "Test App",
|
||||
"slug": "test-app",
|
||||
"description": "A test application",
|
||||
"is_connected": True,
|
||||
"triggers": [
|
||||
{
|
||||
"name": "Test Trigger",
|
||||
"slug": "test-trigger",
|
||||
"description": "A test trigger"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
self.mock_client.get_triggers.return_value = mock_response
|
||||
|
||||
self.triggers_command.list_triggers()
|
||||
|
||||
self.mock_client.get_triggers.assert_called_once()
|
||||
mock_console_print.assert_any_call("[bold blue]Fetching available triggers...[/bold blue]")
|
||||
|
||||
@patch("crewai_cli.triggers.main.console.print")
|
||||
def test_list_triggers_no_apps(self, mock_console_print):
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 200
|
||||
mock_response.ok = True
|
||||
mock_response.json.return_value = {"apps": []}
|
||||
self.mock_client.get_triggers.return_value = mock_response
|
||||
|
||||
self.triggers_command.list_triggers()
|
||||
|
||||
mock_console_print.assert_any_call("[yellow]No triggers found.[/yellow]")
|
||||
|
||||
@patch("crewai_cli.triggers.main.console.print")
|
||||
def test_list_triggers_api_error(self, mock_console_print):
|
||||
self.mock_client.get_triggers.side_effect = Exception("API Error")
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
self.triggers_command.list_triggers()
|
||||
|
||||
mock_console_print.assert_any_call("[bold red]Error fetching triggers: API Error[/bold red]")
|
||||
|
||||
@patch("crewai_cli.triggers.main.console.print")
|
||||
def test_execute_with_trigger_invalid_format(self, mock_console_print):
|
||||
with self.assertRaises(SystemExit):
|
||||
self.triggers_command.execute_with_trigger("invalid-format")
|
||||
|
||||
mock_console_print.assert_called_with(
|
||||
"[bold red]Error: Trigger must be in format 'app_slug/trigger_slug'[/bold red]"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.triggers.main.console.print")
|
||||
@patch.object(TriggersCommand, "_run_crew_with_payload")
|
||||
def test_execute_with_trigger_success(self, mock_run_crew, mock_console_print):
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 200
|
||||
mock_response.ok = True
|
||||
mock_response.json.return_value = {
|
||||
"sample_payload": {"key": "value", "data": "test"}
|
||||
}
|
||||
self.mock_client.get_trigger_payload.return_value = mock_response
|
||||
|
||||
self.triggers_command.execute_with_trigger("test-app/test-trigger")
|
||||
|
||||
self.mock_client.get_trigger_payload.assert_called_once_with("test-app", "test-trigger")
|
||||
mock_run_crew.assert_called_once_with({"key": "value", "data": "test"})
|
||||
mock_console_print.assert_any_call(
|
||||
"[bold blue]Fetching trigger payload for test-app/test-trigger...[/bold blue]"
|
||||
)
|
||||
|
||||
@patch("crewai_cli.triggers.main.console.print")
|
||||
def test_execute_with_trigger_not_found(self, mock_console_print):
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 404
|
||||
mock_response.json.return_value = {"error": "Trigger not found"}
|
||||
self.mock_client.get_trigger_payload.return_value = mock_response
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
self.triggers_command.execute_with_trigger("test-app/nonexistent-trigger")
|
||||
|
||||
mock_console_print.assert_any_call("[bold red]Error: Trigger not found[/bold red]")
|
||||
|
||||
@patch("crewai_cli.triggers.main.console.print")
|
||||
def test_execute_with_trigger_api_error(self, mock_console_print):
|
||||
self.mock_client.get_trigger_payload.side_effect = Exception("API Error")
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
self.triggers_command.execute_with_trigger("test-app/test-trigger")
|
||||
|
||||
mock_console_print.assert_any_call(
|
||||
"[bold red]Error executing crew with trigger: API Error[/bold red]"
|
||||
)
|
||||
|
||||
|
||||
@patch("subprocess.run")
|
||||
def test_run_crew_with_payload_success(self, mock_subprocess):
|
||||
payload = {"key": "value", "data": "test"}
|
||||
mock_subprocess.return_value = None
|
||||
|
||||
self.triggers_command._run_crew_with_payload(payload)
|
||||
|
||||
mock_subprocess.assert_called_once_with(
|
||||
["uv", "run", "run_with_trigger", json.dumps(payload)],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True
|
||||
)
|
||||
|
||||
@patch("subprocess.run")
|
||||
def test_run_crew_with_payload_failure(self, mock_subprocess):
|
||||
payload = {"key": "value"}
|
||||
mock_subprocess.side_effect = subprocess.CalledProcessError(1, "uv")
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
self.triggers_command._run_crew_with_payload(payload)
|
||||
|
||||
@patch("subprocess.run")
|
||||
def test_run_crew_with_payload_empty_payload(self, mock_subprocess):
|
||||
payload = {}
|
||||
mock_subprocess.return_value = None
|
||||
|
||||
self.triggers_command._run_crew_with_payload(payload)
|
||||
|
||||
mock_subprocess.assert_called_once_with(
|
||||
["uv", "run", "run_with_trigger", "{}"],
|
||||
capture_output=False,
|
||||
text=True,
|
||||
check=True
|
||||
)
|
||||
|
||||
@patch("crewai_cli.triggers.main.console.print")
|
||||
def test_execute_with_trigger_with_default_error_message(self, mock_console_print):
|
||||
mock_response = Mock(spec=httpx.Response)
|
||||
mock_response.status_code = 404
|
||||
mock_response.json.return_value = {}
|
||||
self.mock_client.get_trigger_payload.return_value = mock_response
|
||||
|
||||
with self.assertRaises(SystemExit):
|
||||
self.triggers_command.execute_with_trigger("test-app/test-trigger")
|
||||
|
||||
mock_console_print.assert_any_call("[bold red]Error: Trigger not found[/bold red]")
|
||||
Reference in New Issue
Block a user