diff --git a/pkg/internal/env/variables.go b/pkg/internal/env/variables.go new file mode 100644 index 00000000..59e71011 --- /dev/null +++ b/pkg/internal/env/variables.go @@ -0,0 +1,30 @@ +package env + +const ( + // env vars + LoginMethod = "AAD_LOGIN_METHOD" + KubeloginROPCUsername = "AAD_USER_PRINCIPAL_NAME" + KubeloginROPCPassword = "AAD_USER_PRINCIPAL_PASSWORD" + KubeloginClientID = "AAD_SERVICE_PRINCIPAL_CLIENT_ID" + KubeloginClientSecret = "AAD_SERVICE_PRINCIPAL_CLIENT_SECRET" + KubeloginClientCertificatePath = "AAD_SERVICE_PRINCIPAL_CLIENT_CERTIFICATE" + KubeloginClientCertificatePassword = "AAD_SERVICE_PRINCIPAL_CLIENT_CERTIFICATE_PASSWORD" + + // env vars used by Terraform + TerraformClientID = "ARM_CLIENT_ID" + TerraformClientSecret = "ARM_CLIENT_SECRET" + TerraformClientCertificatePath = "ARM_CLIENT_CERTIFICATE_PATH" + TerraformClientCertificatePassword = "ARM_CLIENT_CERTIFICATE_PASSWORD" + TerraformTenantID = "ARM_TENANT_ID" + + // env vars following azure sdk naming convention + AzureAuthorityHost = "AZURE_AUTHORITY_HOST" + AzureClientCertificatePassword = "AZURE_CLIENT_CERTIFICATE_PASSWORD" + AzureClientCertificatePath = "AZURE_CLIENT_CERTIFICATE_PATH" + AzureClientID = "AZURE_CLIENT_ID" + AzureClientSecret = "AZURE_CLIENT_SECRET" + AzureFederatedTokenFile = "AZURE_FEDERATED_TOKEN_FILE" + AzureTenantID = "AZURE_TENANT_ID" + AzureUsername = "AZURE_USERNAME" + AzurePassword = "AZURE_PASSWORD" +) diff --git a/pkg/internal/token/options.go b/pkg/internal/token/options.go index f2fab32c..ccc02a86 100644 --- a/pkg/internal/token/options.go +++ b/pkg/internal/token/options.go @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/Azure/kubelogin/pkg/internal/env" "github.com/spf13/pflag" "k8s.io/client-go/util/homedir" ) @@ -45,33 +46,6 @@ const ( AzureCLILogin = "azurecli" WorkloadIdentityLogin = "workloadidentity" manualTokenLogin = "manual_token" - - // env vars - loginMethod = "AAD_LOGIN_METHOD" - kubeloginROPCUsername = "AAD_USER_PRINCIPAL_NAME" - kubeloginROPCPassword = "AAD_USER_PRINCIPAL_PASSWORD" - kubeloginClientID = "AAD_SERVICE_PRINCIPAL_CLIENT_ID" - kubeloginClientSecret = "AAD_SERVICE_PRINCIPAL_CLIENT_SECRET" - kubeloginClientCertificatePath = "AAD_SERVICE_PRINCIPAL_CLIENT_CERTIFICATE" - kubeloginClientCertificatePassword = "AAD_SERVICE_PRINCIPAL_CLIENT_CERTIFICATE_PASSWORD" - - // env vars used by Terraform - terraformClientID = "ARM_CLIENT_ID" - terraformClientSecret = "ARM_CLIENT_SECRET" - terraformClientCertificatePath = "ARM_CLIENT_CERTIFICATE_PATH" - terraformClientCertificatePassword = "ARM_CLIENT_CERTIFICATE_PASSWORD" - terraformTenantID = "ARM_TENANT_ID" - - // env vars following azure sdk naming convention - azureAuthorityHost = "AZURE_AUTHORITY_HOST" - azureClientCertificatePassword = "AZURE_CLIENT_CERTIFICATE_PASSWORD" - azureClientCertificatePath = "AZURE_CLIENT_CERTIFICATE_PATH" - azureClientID = "AZURE_CLIENT_ID" - azureClientSecret = "AZURE_CLIENT_SECRET" - azureFederatedTokenFile = "AZURE_FEDERATED_TOKEN_FILE" - azureTenantID = "AZURE_TENANT_ID" - azureUsername = "AZURE_USERNAME" - azurePassword = "AZURE_PASSWORD" ) var ( @@ -97,27 +71,27 @@ func NewOptions() Options { func (o *Options) AddFlags(fs *pflag.FlagSet) { fs.StringVarP(&o.LoginMethod, "login", "l", o.LoginMethod, - fmt.Sprintf("Login method. Supported methods: %s. It may be specified in %s environment variable", GetSupportedLogins(), loginMethod)) + fmt.Sprintf("Login method. Supported methods: %s. It may be specified in %s environment variable", GetSupportedLogins(), env.LoginMethod)) fs.StringVar(&o.ClientID, "client-id", o.ClientID, - fmt.Sprintf("AAD client application ID. It may be specified in %s or %s environment variable", kubeloginClientID, azureClientID)) + fmt.Sprintf("AAD client application ID. It may be specified in %s or %s environment variable", env.KubeloginClientID, env.AzureClientID)) fs.StringVar(&o.ClientSecret, "client-secret", o.ClientSecret, - fmt.Sprintf("AAD client application secret. Used in spn login. It may be specified in %s or %s environment variable", kubeloginClientSecret, azureClientSecret)) + fmt.Sprintf("AAD client application secret. Used in spn login. It may be specified in %s or %s environment variable", env.KubeloginClientSecret, env.AzureClientSecret)) fs.StringVar(&o.ClientCert, "client-certificate", o.ClientCert, - fmt.Sprintf("AAD client cert in pfx. Used in spn login. It may be specified in %s or %s environment variable", kubeloginClientCertificatePath, azureClientCertificatePath)) + fmt.Sprintf("AAD client cert in pfx. Used in spn login. It may be specified in %s or %s environment variable", env.KubeloginClientCertificatePath, env.AzureClientCertificatePath)) fs.StringVar(&o.ClientCertPassword, "client-certificate-password", o.ClientCertPassword, - fmt.Sprintf("Password for AAD client cert. Used in spn login. It may be specified in %s or %s environment variable", kubeloginClientCertificatePassword, azureClientCertificatePassword)) + fmt.Sprintf("Password for AAD client cert. Used in spn login. It may be specified in %s or %s environment variable", env.KubeloginClientCertificatePassword, env.AzureClientCertificatePassword)) fs.StringVar(&o.Username, "username", o.Username, - fmt.Sprintf("user name for ropc login flow. It may be specified in %s or %s environment variable", kubeloginROPCUsername, azureUsername)) + fmt.Sprintf("user name for ropc login flow. It may be specified in %s or %s environment variable", env.KubeloginROPCUsername, env.AzureUsername)) fs.StringVar(&o.Password, "password", o.Password, - fmt.Sprintf("password for ropc login flow. It may be specified in %s or %s environment variable", kubeloginROPCPassword, azurePassword)) + fmt.Sprintf("password for ropc login flow. It may be specified in %s or %s environment variable", env.KubeloginROPCPassword, env.AzurePassword)) fs.StringVar(&o.IdentityResourceID, "identity-resource-id", o.IdentityResourceID, "Managed Identity resource id.") fs.StringVar(&o.ServerID, "server-id", o.ServerID, "AAD server application ID") fs.StringVar(&o.FederatedTokenFile, "federated-token-file", o.FederatedTokenFile, - fmt.Sprintf("Workload Identity federated token file. It may be specified in %s environment variable", azureFederatedTokenFile)) + fmt.Sprintf("Workload Identity federated token file. It may be specified in %s environment variable", env.AzureFederatedTokenFile)) fs.StringVar(&o.AuthorityHost, "authority-host", o.AuthorityHost, - fmt.Sprintf("Workload Identity authority host. It may be specified in %s environment variable", azureAuthorityHost)) + fmt.Sprintf("Workload Identity authority host. It may be specified in %s environment variable", env.AzureAuthorityHost)) fs.StringVar(&o.TokenCacheDir, "token-cache-dir", o.TokenCacheDir, "directory to cache token") - fs.StringVarP(&o.TenantID, "tenant-id", "t", o.TenantID, fmt.Sprintf("AAD tenant ID. It may be specified in %s environment variable", azureTenantID)) + fs.StringVarP(&o.TenantID, "tenant-id", "t", o.TenantID, fmt.Sprintf("AAD tenant ID. It may be specified in %s environment variable", env.AzureTenantID)) fs.StringVarP(&o.Environment, "environment", "e", o.Environment, "Azure environment name") fs.BoolVar(&o.IsLegacy, "legacy", o.IsLegacy, "set to true to get token with 'spn:' prefix in audience claim") fs.BoolVar(&o.UseAzureRMTerraformEnv, "use-azurerm-env-vars", o.UseAzureRMTerraformEnv, @@ -160,75 +134,75 @@ func (o *Options) UpdateFromEnv() { o.tokenCacheFile = getCacheFileName(o) if o.UseAzureRMTerraformEnv { - if v, ok := os.LookupEnv(terraformClientID); ok { + if v, ok := os.LookupEnv(env.TerraformClientID); ok { o.ClientID = v } - if v, ok := os.LookupEnv(terraformClientSecret); ok { + if v, ok := os.LookupEnv(env.TerraformClientSecret); ok { o.ClientSecret = v } - if v, ok := os.LookupEnv(terraformClientCertificatePath); ok { + if v, ok := os.LookupEnv(env.TerraformClientCertificatePath); ok { o.ClientCert = v } - if v, ok := os.LookupEnv(terraformClientCertificatePassword); ok { + if v, ok := os.LookupEnv(env.TerraformClientCertificatePassword); ok { o.ClientCertPassword = v } - if v, ok := os.LookupEnv(terraformTenantID); ok { + if v, ok := os.LookupEnv(env.TerraformTenantID); ok { o.TenantID = v } } else { - if v, ok := os.LookupEnv(kubeloginClientID); ok { + if v, ok := os.LookupEnv(env.KubeloginClientID); ok { o.ClientID = v } - if v, ok := os.LookupEnv(azureClientID); ok { + if v, ok := os.LookupEnv(env.AzureClientID); ok { o.ClientID = v } - if v, ok := os.LookupEnv(kubeloginClientSecret); ok { + if v, ok := os.LookupEnv(env.KubeloginClientSecret); ok { o.ClientSecret = v } - if v, ok := os.LookupEnv(azureClientSecret); ok { + if v, ok := os.LookupEnv(env.AzureClientSecret); ok { o.ClientSecret = v } - if v, ok := os.LookupEnv(kubeloginClientCertificatePath); ok { + if v, ok := os.LookupEnv(env.KubeloginClientCertificatePath); ok { o.ClientCert = v } - if v, ok := os.LookupEnv(azureClientCertificatePath); ok { + if v, ok := os.LookupEnv(env.AzureClientCertificatePath); ok { o.ClientCert = v } - if v, ok := os.LookupEnv(kubeloginClientCertificatePassword); ok { + if v, ok := os.LookupEnv(env.KubeloginClientCertificatePassword); ok { o.ClientCertPassword = v } - if v, ok := os.LookupEnv(azureClientCertificatePassword); ok { + if v, ok := os.LookupEnv(env.AzureClientCertificatePassword); ok { o.ClientCertPassword = v } - if v, ok := os.LookupEnv(azureTenantID); ok { + if v, ok := os.LookupEnv(env.AzureTenantID); ok { o.TenantID = v } } - if v, ok := os.LookupEnv(kubeloginROPCUsername); ok { + if v, ok := os.LookupEnv(env.KubeloginROPCUsername); ok { o.Username = v } - if v, ok := os.LookupEnv(azureUsername); ok { + if v, ok := os.LookupEnv(env.AzureUsername); ok { o.Username = v } - if v, ok := os.LookupEnv(kubeloginROPCPassword); ok { + if v, ok := os.LookupEnv(env.KubeloginROPCPassword); ok { o.Password = v } - if v, ok := os.LookupEnv(azurePassword); ok { + if v, ok := os.LookupEnv(env.AzurePassword); ok { o.Password = v } - if v, ok := os.LookupEnv(loginMethod); ok { + if v, ok := os.LookupEnv(env.LoginMethod); ok { o.LoginMethod = v } if o.LoginMethod == WorkloadIdentityLogin { - if v, ok := os.LookupEnv(azureClientID); ok { + if v, ok := os.LookupEnv(env.AzureClientID); ok { o.ClientID = v } - if v, ok := os.LookupEnv(azureFederatedTokenFile); ok { + if v, ok := os.LookupEnv(env.AzureFederatedTokenFile); ok { o.FederatedTokenFile = v } - if v, ok := os.LookupEnv(azureAuthorityHost); ok { + if v, ok := os.LookupEnv(env.AzureAuthorityHost); ok { o.AuthorityHost = v } } diff --git a/pkg/internal/token/options_test.go b/pkg/internal/token/options_test.go index bc6c127b..27e771c2 100644 --- a/pkg/internal/token/options_test.go +++ b/pkg/internal/token/options_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/Azure/kubelogin/pkg/internal/env" "github.com/Azure/kubelogin/pkg/internal/testutils" "github.com/google/go-cmp/cmp" "github.com/spf13/pflag" @@ -86,14 +87,14 @@ func TestOptionsWithEnvVars(t *testing.T) { { name: "setting env var using legacy env var format", envVarMap: map[string]string{ - kubeloginClientID: clientID, - kubeloginClientSecret: clientSecret, - kubeloginClientCertificatePath: certPath, - kubeloginClientCertificatePassword: certPassword, - kubeloginROPCUsername: username, - kubeloginROPCPassword: password, - azureTenantID: tenantID, - loginMethod: DeviceCodeLogin, + env.KubeloginClientID: clientID, + env.KubeloginClientSecret: clientSecret, + env.KubeloginClientCertificatePath: certPath, + env.KubeloginClientCertificatePassword: certPassword, + env.KubeloginROPCUsername: username, + env.KubeloginROPCPassword: password, + env.AzureTenantID: tenantID, + env.LoginMethod: DeviceCodeLogin, }, expected: Options{ ClientID: clientID, @@ -112,12 +113,12 @@ func TestOptionsWithEnvVars(t *testing.T) { name: "setting env var using terraform env var format", isTerraform: true, envVarMap: map[string]string{ - terraformClientID: clientID, - terraformClientSecret: clientSecret, - terraformClientCertificatePath: certPath, - terraformClientCertificatePassword: certPassword, - terraformTenantID: tenantID, - loginMethod: DeviceCodeLogin, + env.TerraformClientID: clientID, + env.TerraformClientSecret: clientSecret, + env.TerraformClientCertificatePath: certPath, + env.TerraformClientCertificatePassword: certPassword, + env.TerraformTenantID: tenantID, + env.LoginMethod: DeviceCodeLogin, }, expected: Options{ UseAzureRMTerraformEnv: true, @@ -134,16 +135,16 @@ func TestOptionsWithEnvVars(t *testing.T) { { name: "setting env var using azure sdk env var format", envVarMap: map[string]string{ - azureClientID: clientID, - azureClientSecret: clientSecret, - azureClientCertificatePath: certPath, - azureClientCertificatePassword: certPassword, - azureUsername: username, - azurePassword: password, - azureTenantID: tenantID, - loginMethod: WorkloadIdentityLogin, - azureFederatedTokenFile: tokenFile, - azureAuthorityHost: authorityHost, + env.AzureClientID: clientID, + env.AzureClientSecret: clientSecret, + env.AzureClientCertificatePath: certPath, + env.AzureClientCertificatePassword: certPassword, + env.AzureUsername: username, + env.AzurePassword: password, + env.AzureTenantID: tenantID, + env.LoginMethod: WorkloadIdentityLogin, + env.AzureFederatedTokenFile: tokenFile, + env.AzureAuthorityHost: authorityHost, }, expected: Options{ ClientID: clientID, diff --git a/pkg/token/options.go b/pkg/token/options.go new file mode 100644 index 00000000..15dc9c14 --- /dev/null +++ b/pkg/token/options.go @@ -0,0 +1,42 @@ +package token + +import "github.com/Azure/kubelogin/pkg/internal/token" + +// list of supported login methods for library consumers + +const ( + ServicePrincipalLogin = token.ServicePrincipalLogin + MSILogin = token.MSILogin + WorkloadIdentityLogin = token.WorkloadIdentityLogin +) + +// Options defines the options for getting token. +// This struct is a subset of internal/token.Options where its values are copied +// to internal type. See internal/token/options.go for details +type Options struct { + LoginMethod string + + // shared login settings + + Environment string + TenantID string + ServerID string + ClientID string + + // for ServicePrincipalLogin + + ClientSecret string + ClientCert string + ClientCertPassword string + IsPoPTokenEnabled bool + PoPTokenClaims string + + // for MSILogin + + IdentityResourceID string + + // for WorkloadIdentityLogin + + AuthorityHost string + FederatedTokenFile string +} diff --git a/pkg/token/options_ctor.go b/pkg/token/options_ctor.go new file mode 100644 index 00000000..f91499a3 --- /dev/null +++ b/pkg/token/options_ctor.go @@ -0,0 +1,57 @@ +package token + +import ( + "os" + + "github.com/Azure/kubelogin/pkg/internal/env" + "github.com/Azure/kubelogin/pkg/internal/token" +) + +// OptionsWithEnv loads options from environment variables. +func OptionsWithEnv() *Options { + // initial default values + rv := &Options{ + LoginMethod: os.Getenv(env.LoginMethod), + TenantID: os.Getenv(env.AzureTenantID), + ClientID: os.Getenv(env.KubeloginClientID), + ClientSecret: os.Getenv(env.KubeloginClientSecret), + ClientCert: os.Getenv(env.KubeloginClientCertificatePath), + ClientCertPassword: os.Getenv(env.KubeloginClientCertificatePassword), + AuthorityHost: os.Getenv(env.AzureAuthorityHost), + FederatedTokenFile: os.Getenv(env.AzureFederatedTokenFile), + } + + // azure variant overrides + if v, ok := os.LookupEnv(env.AzureClientID); ok { + rv.ClientID = v + } + if v, ok := os.LookupEnv(env.AzureClientSecret); ok { + rv.ClientSecret = v + } + if v, ok := os.LookupEnv(env.AzureClientCertificatePath); ok { + rv.ClientCert = v + } + if v, ok := os.LookupEnv(env.AzureClientCertificatePassword); ok { + rv.ClientCertPassword = v + } + + return rv +} + +func (opts *Options) toInternalOptions() *token.Options { + return &token.Options{ + LoginMethod: opts.LoginMethod, + Environment: opts.Environment, + TenantID: opts.TenantID, + ServerID: opts.ServerID, + ClientID: opts.ClientID, + ClientSecret: opts.ClientSecret, + ClientCert: opts.ClientCert, + ClientCertPassword: opts.ClientCertPassword, + IsPoPTokenEnabled: opts.IsPoPTokenEnabled, + PoPTokenClaims: opts.PoPTokenClaims, + IdentityResourceID: opts.IdentityResourceID, + AuthorityHost: opts.AuthorityHost, + FederatedTokenFile: opts.FederatedTokenFile, + } +} diff --git a/pkg/token/options_ctor_test.go b/pkg/token/options_ctor_test.go new file mode 100644 index 00000000..be30a187 --- /dev/null +++ b/pkg/token/options_ctor_test.go @@ -0,0 +1,154 @@ +package token + +import ( + "reflect" + "testing" + + "github.com/Azure/kubelogin/pkg/internal/env" + "github.com/Azure/kubelogin/pkg/internal/token" + "github.com/stretchr/testify/assert" +) + +func TestOptionsWithEnv(t *testing.T) { + t.Run("no env vars", func(t *testing.T) { + o := OptionsWithEnv() + assert.Equal(t, &Options{}, o) + }) + + t.Run("with kubelogin variant env vars", func(t *testing.T) { + for k, v := range map[string]string{ + env.LoginMethod: MSILogin, + env.AzureTenantID: "tenant-id", + env.KubeloginClientID: "client-id", + env.KubeloginClientSecret: "client-secret", + env.KubeloginClientCertificatePath: "client-cert-path", + env.KubeloginClientCertificatePassword: "client-cert-password", + env.AzureAuthorityHost: "authority-host", + env.AzureFederatedTokenFile: "federated-token-file", + } { + t.Setenv(k, v) + } + + o := OptionsWithEnv() + assert.Equal(t, &Options{ + LoginMethod: MSILogin, + TenantID: "tenant-id", + ClientID: "client-id", + ClientSecret: "client-secret", + ClientCert: "client-cert-path", + ClientCertPassword: "client-cert-password", + AuthorityHost: "authority-host", + FederatedTokenFile: "federated-token-file", + }, o) + }) + + t.Run("with azure variant env vars", func(t *testing.T) { + for k, v := range map[string]string{ + env.LoginMethod: MSILogin, + env.AzureTenantID: "tenant-id", + env.KubeloginClientID: "client-id", + env.AzureClientID: "azure-client-id", + env.KubeloginClientSecret: "client-secret", + env.AzureClientSecret: "azure-client-secret", + env.KubeloginClientCertificatePath: "client-cert-path", + env.AzureClientCertificatePath: "azure-client-cert-path", + env.KubeloginClientCertificatePassword: "client-cert-password", + env.AzureClientCertificatePassword: "azure-client-cert-password", + env.AzureAuthorityHost: "authority-host", + env.AzureFederatedTokenFile: "federated-token-file", + } { + t.Setenv(k, v) + } + + o := OptionsWithEnv() + assert.Equal(t, &Options{ + LoginMethod: MSILogin, + TenantID: "tenant-id", + ClientID: "azure-client-id", + ClientSecret: "azure-client-secret", + ClientCert: "azure-client-cert-path", + ClientCertPassword: "azure-client-cert-password", + AuthorityHost: "authority-host", + FederatedTokenFile: "federated-token-file", + }, o) + }) +} + +func TestOptions_toInternalOptions(t *testing.T) { + t.Run("basic", func(t *testing.T) { + o := &Options{ + LoginMethod: "login-method", + Environment: "environment", + TenantID: "tenant-id", + ServerID: "server-id", + ClientID: "client-id", + ClientSecret: "client-secret", + ClientCert: "client-cert", + ClientCertPassword: "client-cert-password", + IsPoPTokenEnabled: true, + PoPTokenClaims: "pop-token-claims", + IdentityResourceID: "identity-resource-id", + AuthorityHost: "authority-host", + FederatedTokenFile: "federated-token-file", + } + assert.Equal(t, &token.Options{ + LoginMethod: "login-method", + Environment: "environment", + TenantID: "tenant-id", + ServerID: "server-id", + ClientID: "client-id", + ClientSecret: "client-secret", + ClientCert: "client-cert", + ClientCertPassword: "client-cert-password", + IsPoPTokenEnabled: true, + PoPTokenClaims: "pop-token-claims", + IdentityResourceID: "identity-resource-id", + AuthorityHost: "authority-host", + FederatedTokenFile: "federated-token-file", + }, o.toInternalOptions()) + }) + + // this test uses reflection to ensure all fields in *Options + // are copied to *token.Options without modification. + t.Run("fields assignment", func(t *testing.T) { + boolValue := true + stringValue := "string-value" + + o := &Options{} + + // fill up all fields in *Options + oType := reflect.TypeOf(o).Elem() + oValue := reflect.ValueOf(o).Elem() + for i := 0; i < oValue.NumField(); i++ { + fieldValue := oValue.Field(i) + fieldType := oType.Field(i) + switch k := fieldType.Type.Kind(); k { + case reflect.Bool: + // set bool value + fieldValue.SetBool(boolValue) + case reflect.String: + fieldValue.SetString(stringValue) + default: + t.Errorf("unexpected type: %s", k) + } + } + + internalOpts := o.toInternalOptions() + assert.NotNil(t, internalOpts) + + internalOptsValue := reflect.ValueOf(internalOpts).Elem() + for i := 0; i < oValue.NumField(); i++ { + fieldType := oType.Field(i) + t.Log(fieldType.Name) + internalOptsFieldValue := internalOptsValue.FieldByName(fieldType.Name) + switch k := fieldType.Type.Kind(); k { + case reflect.Bool: + assert.Equal(t, boolValue, internalOptsFieldValue.Bool(), "field: %s", fieldType.Name) + case reflect.String: + assert.Equal(t, stringValue, internalOptsFieldValue.String(), "field: %s", fieldType.Name) + default: + t.Errorf("unexpected type: %s", k) + } + } + }) +} diff --git a/pkg/token/provider.go b/pkg/token/provider.go new file mode 100644 index 00000000..479d9ce1 --- /dev/null +++ b/pkg/token/provider.go @@ -0,0 +1,41 @@ +package token + +import ( + "context" + + "github.com/Azure/kubelogin/pkg/internal/token" +) + +type tokenProviderShim struct { + impl token.TokenProvider +} + +var _ TokenProvider = (*tokenProviderShim)(nil) + +func (tp *tokenProviderShim) GetAccessToken(ctx context.Context) (AccessToken, error) { + t, err := tp.impl.Token(ctx) + if err != nil { + return AccessToken{}, err + } + + rv := AccessToken{ + Token: t.AccessToken, + ExpiresOn: t.Expires(), + } + + return rv, nil +} + +// GetTokenProvider returns a token provider based on the given options. +func GetTokenProvider(options *Options) (TokenProvider, error) { + impl, err := token.NewTokenProvider(options.toInternalOptions()) + if err != nil { + return nil, err + } + + rv := &tokenProviderShim{ + impl: impl, + } + + return rv, nil +} diff --git a/pkg/token/provider_test.go b/pkg/token/provider_test.go new file mode 100644 index 00000000..4bc41e00 --- /dev/null +++ b/pkg/token/provider_test.go @@ -0,0 +1,74 @@ +package token + +import ( + "context" + "encoding/json" + "testing" + + "github.com/Azure/go-autorest/autorest/adal" + "github.com/Azure/kubelogin/pkg/internal/token/mock_token" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" +) + +func TestGetTokenProvider(t *testing.T) { + t.Run("invalid login method", func(t *testing.T) { + opts := &Options{ + LoginMethod: "invalid-login-method", + } + tp, err := GetTokenProvider(opts) + assert.Error(t, err) + assert.Nil(t, tp) + }) + + t.Run("basic", func(t *testing.T) { + opts := &Options{ + LoginMethod: MSILogin, + ClientID: "client-id", + IdentityResourceID: "identity-resource-id", + ServerID: "server-id", + } + tp, err := GetTokenProvider(opts) + assert.NoError(t, err) + assert.NotNil(t, tp) + }) +} + +func TestTokenProviderShim_GetAccessToken(t *testing.T) { + t.Run("failure case", func(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + mockTokenProvider := mock_token.NewMockTokenProvider(mockCtrl) + mockTokenProvider.EXPECT().Token(gomock.Any()).Return(adal.Token{}, assert.AnError) + + tp := &tokenProviderShim{ + impl: mockTokenProvider, + } + + token, err := tp.GetAccessToken(context.Background()) + assert.Equal(t, AccessToken{}, token) + assert.Equal(t, assert.AnError, err) + }) + + t.Run("success case", func(t *testing.T) { + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + + adalToken := adal.Token{ + AccessToken: "access-token", + ExpiresOn: json.Number("1700000000"), + } + mockTokenProvider := mock_token.NewMockTokenProvider(mockCtrl) + mockTokenProvider.EXPECT().Token(gomock.Any()).Return(adalToken, nil) + + tp := &tokenProviderShim{ + impl: mockTokenProvider, + } + + token, err := tp.GetAccessToken(context.Background()) + assert.NoError(t, err) + assert.Equal(t, adalToken.AccessToken, token.Token) + assert.Equal(t, adalToken.Expires(), token.ExpiresOn) + }) +} diff --git a/pkg/token/types.go b/pkg/token/types.go new file mode 100644 index 00000000..c03480d5 --- /dev/null +++ b/pkg/token/types.go @@ -0,0 +1,16 @@ +package token + +import ( + "context" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" +) + +// AccessToken represents an Azure service bearer access token with expiry information. +type AccessToken = azcore.AccessToken + +// TokenProvider provides access to tokens. +type TokenProvider interface { + // GetAccessToken returns an access token from given settings. + GetAccessToken(ctx context.Context) (AccessToken, error) +} diff --git a/version_test.go b/version_test.go index 913799a8..9df4ca36 100644 --- a/version_test.go +++ b/version_test.go @@ -8,4 +8,4 @@ func Test_loadVersion(t *testing.T) { if versionString == "" { t.Errorf("versionString is empty") } -} \ No newline at end of file +}