diff --git a/pkg/blockstorage/azure/azuredisk.go b/pkg/blockstorage/azure/azuredisk.go index ee4309616c..dbdf7025f7 100644 --- a/pkg/blockstorage/azure/azuredisk.go +++ b/pkg/blockstorage/azure/azuredisk.go @@ -653,7 +653,7 @@ func (s *AdStorage) SnapshotRestoreTargets(ctx context.Context, snapshot *blocks // dynamicRegionMapAzure derives a mapping from Regions to zones for Azure. Depends on subscriptionID func (s *AdStorage) dynamicRegionMapAzure(ctx context.Context) (map[string][]string, error) { - subscriptionsCLient := subscriptions.NewClient() + subscriptionsCLient := subscriptions.NewClientWithBaseURI(s.azCli.BaseURI) subscriptionsCLient.Authorizer = s.azCli.Authorizer llResp, err := subscriptionsCLient.ListLocations(ctx, s.azCli.SubscriptionID) if err != nil { @@ -664,7 +664,7 @@ func (s *AdStorage) dynamicRegionMapAzure(ctx context.Context) (map[string][]str regionMap[*location.Name] = make(map[string]struct{}) } - skuClient := skus.NewResourceSkusClient(s.azCli.SubscriptionID) + skuClient := skus.NewResourceSkusClientWithBaseURI(s.azCli.BaseURI, s.azCli.SubscriptionID) skuClient.Authorizer = s.azCli.Authorizer skuResults, err := skuClient.ListComplete(ctx) if err != nil { diff --git a/pkg/blockstorage/azure/client.go b/pkg/blockstorage/azure/client.go index 3da068bb9d..d8ae687bf8 100644 --- a/pkg/blockstorage/azure/client.go +++ b/pkg/blockstorage/azure/client.go @@ -79,16 +79,14 @@ func NewClient(ctx context.Context, config map[string]string) (*Client, error) { baseURI, ok = config[blockstorage.AzureResurceMgrEndpoint] if !ok { - baseURI = compute.DefaultBaseURI + baseURI = env.ResourceManagerEndpoint } - disksClient := compute.NewDisksClient(subscriptionID) + disksClient := compute.NewDisksClientWithBaseURI(baseURI, subscriptionID) disksClient.Authorizer = authorizer - disksClient.BaseURI = baseURI - snapshotsClient := compute.NewSnapshotsClient(subscriptionID) + snapshotsClient := compute.NewSnapshotsClientWithBaseURI(baseURI, subscriptionID) snapshotsClient.Authorizer = authorizer - snapshotsClient.BaseURI = baseURI return &Client{ BaseURI: baseURI, @@ -102,38 +100,46 @@ func NewClient(ctx context.Context, config map[string]string) (*Client, error) { // nolint:unparam func getAuthorizer(env azure.Environment, config map[string]string) (*autorest.BearerAuthorizer, error) { + credConfig, err := getCredConfig(env, config) + if err != nil { + return nil, errors.Wrap(err, "Failed to get Azure ClientCredentialsConfig") + } + a, err := credConfig.Authorizer() + if err != nil { + return nil, errors.Wrap(err, "Failed to get Azure authorizer") + } + + ba, ok := a.(*autorest.BearerAuthorizer) + if !ok { + return nil, errors.New("Failed to get Azure authorizer") + } + return ba, nil +} + +func getCredConfig(env azure.Environment, config map[string]string) (auth.ClientCredentialsConfig, error) { tenantID, ok := config[blockstorage.AzureTenantID] if !ok { - return nil, errors.New("Cannot get tenantID from config") + return auth.ClientCredentialsConfig{}, errors.New("Cannot get tenantID from config") } clientID, ok := config[blockstorage.AzureCientID] if !ok { - return nil, errors.New("Cannot get clientID from config") + return auth.ClientCredentialsConfig{}, errors.New("Cannot get clientID from config") } clientSecret, ok := config[blockstorage.AzureClentSecret] if !ok { - return nil, errors.New("Cannot get clientSecret from config") + return auth.ClientCredentialsConfig{}, errors.New("Cannot get clientSecret from config") } credConfig := auth.NewClientCredentialsConfig(clientID, clientSecret, tenantID) - if aDDEndpoint, ok := config[blockstorage.AzureActiveDirEndpoint]; ok { - credConfig.AADEndpoint = aDDEndpoint - } - - if aDDResourceID, ok := config[blockstorage.AzureActiveDirResourceID]; ok { - credConfig.Resource = aDDResourceID + if credConfig.AADEndpoint, ok = config[blockstorage.AzureActiveDirEndpoint]; !ok || credConfig.AADEndpoint == "" { + credConfig.AADEndpoint = env.ActiveDirectoryEndpoint } - a, err := credConfig.Authorizer() - if err != nil { - return nil, errors.Wrap(err, "Failed to get Azure authorizer") + if credConfig.Resource, ok = config[blockstorage.AzureActiveDirResourceID]; !ok || credConfig.Resource == "" { + credConfig.Resource = env.ResourceManagerEndpoint } - ba, ok := a.(*autorest.BearerAuthorizer) - if !ok { - return nil, errors.New("Failed to get Azure authorizer") - } - return ba, nil + return credConfig, nil } diff --git a/pkg/blockstorage/azure/client_test.go b/pkg/blockstorage/azure/client_test.go index 5efbe8dd22..85eeb80f3d 100644 --- a/pkg/blockstorage/azure/client_test.go +++ b/pkg/blockstorage/azure/client_test.go @@ -20,6 +20,8 @@ import ( "strings" "testing" + "github.com/Azure/go-autorest/autorest/azure" + "github.com/Azure/go-autorest/autorest/azure/auth" "github.com/kanisterio/kanister/pkg/blockstorage" envconfig "github.com/kanisterio/kanister/pkg/config" . "gopkg.in/check.v1" @@ -82,3 +84,95 @@ func (s ClientSuite) TestGetRegions(c *C) { c.Assert(err, IsNil) c.Assert(regions, NotNil) } + +func (s *ClientSuite) TestGetCredConfig(c *C) { + for _, tc := range []struct { + env azure.Environment + config map[string]string + errChecker Checker + expCCC auth.ClientCredentialsConfig + }{ + { + env: azure.PublicCloud, + config: map[string]string{ + blockstorage.AzureTenantID: "atid", + blockstorage.AzureCientID: "acid", + blockstorage.AzureClentSecret: "acs", + blockstorage.AzureActiveDirEndpoint: "aade", + blockstorage.AzureActiveDirResourceID: "aadrid", + }, + expCCC: auth.ClientCredentialsConfig{ + ClientID: "acid", + ClientSecret: "acs", + TenantID: "atid", + Resource: "aadrid", + AADEndpoint: "aade", + }, + errChecker: IsNil, + }, + { + env: azure.PublicCloud, + config: map[string]string{ + blockstorage.AzureTenantID: "atid", + blockstorage.AzureCientID: "acid", + blockstorage.AzureClentSecret: "acs", + }, + expCCC: auth.ClientCredentialsConfig{ + ClientID: "acid", + ClientSecret: "acs", + TenantID: "atid", + Resource: azure.PublicCloud.ResourceManagerEndpoint, + AADEndpoint: azure.PublicCloud.ActiveDirectoryEndpoint, + }, + errChecker: IsNil, + }, + { + env: azure.USGovernmentCloud, + config: map[string]string{ + blockstorage.AzureTenantID: "atid", + blockstorage.AzureCientID: "acid", + blockstorage.AzureClentSecret: "acs", + blockstorage.AzureActiveDirEndpoint: "", + blockstorage.AzureActiveDirResourceID: "", + }, + expCCC: auth.ClientCredentialsConfig{ + ClientID: "acid", + ClientSecret: "acs", + TenantID: "atid", + Resource: azure.USGovernmentCloud.ResourceManagerEndpoint, + AADEndpoint: azure.USGovernmentCloud.ActiveDirectoryEndpoint, + }, + errChecker: IsNil, + }, + { + env: azure.USGovernmentCloud, + config: map[string]string{ + blockstorage.AzureTenantID: "atid", + blockstorage.AzureCientID: "acid", + }, + errChecker: NotNil, + }, + { + env: azure.USGovernmentCloud, + config: map[string]string{ + blockstorage.AzureTenantID: "atid", + }, + errChecker: NotNil, + }, + { + env: azure.USGovernmentCloud, + config: map[string]string{}, + errChecker: NotNil, + }, + } { + ccc, err := getCredConfig(tc.env, tc.config) + c.Assert(err, tc.errChecker) + if err == nil { + c.Assert(ccc.ClientID, Equals, tc.expCCC.ClientID) + c.Assert(ccc.ClientSecret, Equals, tc.expCCC.ClientSecret) + c.Assert(ccc.TenantID, Equals, tc.expCCC.TenantID) + c.Assert(ccc.Resource, Equals, tc.expCCC.Resource) + c.Assert(ccc.AADEndpoint, Equals, tc.expCCC.AADEndpoint) + } + } +}