diff --git a/sdk/azcore/CHANGELOG.md b/sdk/azcore/CHANGELOG.md index 999b2cdda712..df3c7b9338d8 100644 --- a/sdk/azcore/CHANGELOG.md +++ b/sdk/azcore/CHANGELOG.md @@ -3,12 +3,14 @@ ## 1.3.2 (Unreleased) ### Features Added +* Add `Clone()` method for `arm/policy.ClientOptions`. ### Breaking Changes ### Bugs Fixed * ARM's RP registration policy will no longer swallow unrecognized errors. * Fixed an issue in `runtime.NewPollerFromResumeToken()` when resuming a `Poller` with a custom `PollingHandler`. +* Fixed wrong policy copy in `arm/runtime.NewPipeline()`. ### Other Changes diff --git a/sdk/azcore/arm/policy/policy.go b/sdk/azcore/arm/policy/policy.go index 098aa6993467..7a700d661e9e 100644 --- a/sdk/azcore/arm/policy/policy.go +++ b/sdk/azcore/arm/policy/policy.go @@ -47,3 +47,40 @@ type ClientOptions struct { // DisableRPRegistration disables the auto-RP registration policy. Defaults to false. DisableRPRegistration bool } + +// Clone return a deep copy of the current options. +func (o *ClientOptions) Clone() *ClientOptions { + if o == nil { + return nil + } + copiedOptions := *o + copiedOptions.Cloud.Services = copyMap(copiedOptions.Cloud.Services) + copiedOptions.Logging.AllowedHeaders = copyArray(copiedOptions.Logging.AllowedHeaders) + copiedOptions.Logging.AllowedQueryParams = copyArray(copiedOptions.Logging.AllowedQueryParams) + copiedOptions.Retry.StatusCodes = copyArray(copiedOptions.Retry.StatusCodes) + copiedOptions.PerRetryPolicies = copyArray(copiedOptions.PerRetryPolicies) + copiedOptions.PerCallPolicies = copyArray(copiedOptions.PerCallPolicies) + return &copiedOptions +} + +// copyMap return a new map with all the key value pair in the src map +func copyMap[K comparable, V any](src map[K]V) map[K]V { + if src == nil { + return nil + } + copiedMap := make(map[K]V) + for k, v := range src { + copiedMap[k] = v + } + return copiedMap +} + +// copyMap return a new array with all the elements in the src array +func copyArray[T any](src []T) []T { + if src == nil { + return nil + } + copiedArray := make([]T, len(src)) + copy(copiedArray, src) + return copiedArray +} diff --git a/sdk/azcore/arm/policy/policy_test.go b/sdk/azcore/arm/policy/policy_test.go new file mode 100644 index 000000000000..7c171daeb7d4 --- /dev/null +++ b/sdk/azcore/arm/policy/policy_test.go @@ -0,0 +1,47 @@ +//go:build go1.18 +// +build go1.18 + +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +package policy + +import ( + "fmt" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" + "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime" + "github.com/stretchr/testify/require" + "testing" +) + +func TestClientOptions_Copy(t *testing.T) { + var option *ClientOptions + require.Nil(t, option.Clone()) + + option = &ClientOptions{ClientOptions: policy.ClientOptions{ + Cloud: cloud.AzurePublic, + Logging: policy.LogOptions{ + AllowedHeaders: []string{"test1", "test2"}, + AllowedQueryParams: []string{"test1", "test2"}, + }, + Retry: policy.RetryOptions{StatusCodes: []int{1, 2}}, + PerRetryPolicies: []policy.Policy{runtime.NewLogPolicy(nil)}, + PerCallPolicies: []policy.Policy{runtime.NewLogPolicy(nil)}, + }} + copiedOption := option.Clone() + require.Equal(t, option.APIVersion, copiedOption.APIVersion) + require.NotEqual(t, fmt.Sprintf("%p", &option.APIVersion), fmt.Sprintf("%p", &copiedOption.APIVersion)) + require.Equal(t, option.Cloud.Services, copiedOption.Cloud.Services) + require.NotEqual(t, fmt.Sprintf("%p", option.Cloud.Services), fmt.Sprintf("%p", copiedOption.Cloud.Services)) + require.Equal(t, option.Logging.AllowedHeaders, copiedOption.Logging.AllowedHeaders) + require.NotEqual(t, fmt.Sprintf("%p", option.Logging.AllowedHeaders), fmt.Sprintf("%p", copiedOption.Logging.AllowedHeaders)) + require.Equal(t, option.Logging.AllowedQueryParams, copiedOption.Logging.AllowedQueryParams) + require.NotEqual(t, fmt.Sprintf("%p", option.Logging.AllowedQueryParams), fmt.Sprintf("%p", copiedOption.Logging.AllowedQueryParams)) + require.Equal(t, option.Retry.StatusCodes, copiedOption.Retry.StatusCodes) + require.NotEqual(t, fmt.Sprintf("%p", option.Retry.StatusCodes), fmt.Sprintf("%p", copiedOption.Retry.StatusCodes)) + require.Equal(t, option.PerRetryPolicies, copiedOption.PerRetryPolicies) + require.NotEqual(t, fmt.Sprintf("%p", option.PerRetryPolicies), fmt.Sprintf("%p", copiedOption.PerRetryPolicies)) + require.Equal(t, option.PerCallPolicies, copiedOption.PerCallPolicies) + require.NotEqual(t, fmt.Sprintf("%p", option.PerCallPolicies), fmt.Sprintf("%p", copiedOption.PerCallPolicies)) +} diff --git a/sdk/azcore/arm/runtime/pipeline.go b/sdk/azcore/arm/runtime/pipeline.go index a2e897765d98..8da2153307b2 100644 --- a/sdk/azcore/arm/runtime/pipeline.go +++ b/sdk/azcore/arm/runtime/pipeline.go @@ -29,7 +29,7 @@ func NewPipeline(module, version string, cred azcore.TokenCredential, plOpts azr return azruntime.Pipeline{}, err } authPolicy := NewBearerTokenPolicy(cred, &armpolicy.BearerTokenOptions{Scopes: []string{conf.Audience + "/.default"}}) - perRetry := make([]azpolicy.Policy, 0, len(plOpts.PerRetry)+1) + perRetry := make([]azpolicy.Policy, len(plOpts.PerRetry), len(plOpts.PerRetry)+1) copy(perRetry, plOpts.PerRetry) plOpts.PerRetry = append(perRetry, authPolicy) if !options.DisableRPRegistration { @@ -38,7 +38,7 @@ func NewPipeline(module, version string, cred azcore.TokenCredential, plOpts azr if err != nil { return azruntime.Pipeline{}, err } - perCall := make([]azpolicy.Policy, 0, len(plOpts.PerCall)+1) + perCall := make([]azpolicy.Policy, len(plOpts.PerCall), len(plOpts.PerCall)+1) copy(perCall, plOpts.PerCall) plOpts.PerCall = append(perCall, regPolicy) }