Skip to content

Commit

Permalink
Add Clone() method for arm/policy.ClientOptions (#20288)
Browse files Browse the repository at this point in the history
* fix empty policy copy problem for arm/runtime.NewPipeline

* add Copy() method for arm/policy.ClientOptions

* rename Copy() to Clone() and fix ci failure
  • Loading branch information
tadelesh authored Feb 27, 2023
1 parent 15aa35d commit 29ba214
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 2 deletions.
2 changes: 2 additions & 0 deletions sdk/azcore/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
## 1.3.2 (Unreleased)

### Features Added
* Add `Clone()` method for `arm/policy.ClientOptions`.

### Breaking Changes

### Bugs Fixed
* ARM's RP registration policy will no longer swallow unrecognized errors.
* Fixed an issue in `runtime.NewPollerFromResumeToken()` when resuming a `Poller` with a custom `PollingHandler`.
* Fixed wrong policy copy in `arm/runtime.NewPipeline()`.

### Other Changes

Expand Down
37 changes: 37 additions & 0 deletions sdk/azcore/arm/policy/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,40 @@ type ClientOptions struct {
// DisableRPRegistration disables the auto-RP registration policy. Defaults to false.
DisableRPRegistration bool
}

// Clone return a deep copy of the current options.
func (o *ClientOptions) Clone() *ClientOptions {
if o == nil {
return nil
}
copiedOptions := *o
copiedOptions.Cloud.Services = copyMap(copiedOptions.Cloud.Services)
copiedOptions.Logging.AllowedHeaders = copyArray(copiedOptions.Logging.AllowedHeaders)
copiedOptions.Logging.AllowedQueryParams = copyArray(copiedOptions.Logging.AllowedQueryParams)
copiedOptions.Retry.StatusCodes = copyArray(copiedOptions.Retry.StatusCodes)
copiedOptions.PerRetryPolicies = copyArray(copiedOptions.PerRetryPolicies)
copiedOptions.PerCallPolicies = copyArray(copiedOptions.PerCallPolicies)
return &copiedOptions
}

// copyMap return a new map with all the key value pair in the src map
func copyMap[K comparable, V any](src map[K]V) map[K]V {
if src == nil {
return nil
}
copiedMap := make(map[K]V)
for k, v := range src {
copiedMap[k] = v
}
return copiedMap
}

// copyMap return a new array with all the elements in the src array
func copyArray[T any](src []T) []T {
if src == nil {
return nil
}
copiedArray := make([]T, len(src))
copy(copiedArray, src)
return copiedArray
}
47 changes: 47 additions & 0 deletions sdk/azcore/arm/policy/policy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
//go:build go1.18
// +build go1.18

// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

package policy

import (
"fmt"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
"github.com/stretchr/testify/require"
"testing"
)

func TestClientOptions_Copy(t *testing.T) {
var option *ClientOptions
require.Nil(t, option.Clone())

option = &ClientOptions{ClientOptions: policy.ClientOptions{
Cloud: cloud.AzurePublic,
Logging: policy.LogOptions{
AllowedHeaders: []string{"test1", "test2"},
AllowedQueryParams: []string{"test1", "test2"},
},
Retry: policy.RetryOptions{StatusCodes: []int{1, 2}},
PerRetryPolicies: []policy.Policy{runtime.NewLogPolicy(nil)},
PerCallPolicies: []policy.Policy{runtime.NewLogPolicy(nil)},
}}
copiedOption := option.Clone()
require.Equal(t, option.APIVersion, copiedOption.APIVersion)
require.NotEqual(t, fmt.Sprintf("%p", &option.APIVersion), fmt.Sprintf("%p", &copiedOption.APIVersion))
require.Equal(t, option.Cloud.Services, copiedOption.Cloud.Services)
require.NotEqual(t, fmt.Sprintf("%p", option.Cloud.Services), fmt.Sprintf("%p", copiedOption.Cloud.Services))
require.Equal(t, option.Logging.AllowedHeaders, copiedOption.Logging.AllowedHeaders)
require.NotEqual(t, fmt.Sprintf("%p", option.Logging.AllowedHeaders), fmt.Sprintf("%p", copiedOption.Logging.AllowedHeaders))
require.Equal(t, option.Logging.AllowedQueryParams, copiedOption.Logging.AllowedQueryParams)
require.NotEqual(t, fmt.Sprintf("%p", option.Logging.AllowedQueryParams), fmt.Sprintf("%p", copiedOption.Logging.AllowedQueryParams))
require.Equal(t, option.Retry.StatusCodes, copiedOption.Retry.StatusCodes)
require.NotEqual(t, fmt.Sprintf("%p", option.Retry.StatusCodes), fmt.Sprintf("%p", copiedOption.Retry.StatusCodes))
require.Equal(t, option.PerRetryPolicies, copiedOption.PerRetryPolicies)
require.NotEqual(t, fmt.Sprintf("%p", option.PerRetryPolicies), fmt.Sprintf("%p", copiedOption.PerRetryPolicies))
require.Equal(t, option.PerCallPolicies, copiedOption.PerCallPolicies)
require.NotEqual(t, fmt.Sprintf("%p", option.PerCallPolicies), fmt.Sprintf("%p", copiedOption.PerCallPolicies))
}
4 changes: 2 additions & 2 deletions sdk/azcore/arm/runtime/pipeline.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func NewPipeline(module, version string, cred azcore.TokenCredential, plOpts azr
return azruntime.Pipeline{}, err
}
authPolicy := NewBearerTokenPolicy(cred, &armpolicy.BearerTokenOptions{Scopes: []string{conf.Audience + "/.default"}})
perRetry := make([]azpolicy.Policy, 0, len(plOpts.PerRetry)+1)
perRetry := make([]azpolicy.Policy, len(plOpts.PerRetry), len(plOpts.PerRetry)+1)
copy(perRetry, plOpts.PerRetry)
plOpts.PerRetry = append(perRetry, authPolicy)
if !options.DisableRPRegistration {
Expand All @@ -38,7 +38,7 @@ func NewPipeline(module, version string, cred azcore.TokenCredential, plOpts azr
if err != nil {
return azruntime.Pipeline{}, err
}
perCall := make([]azpolicy.Policy, 0, len(plOpts.PerCall)+1)
perCall := make([]azpolicy.Policy, len(plOpts.PerCall), len(plOpts.PerCall)+1)
copy(perCall, plOpts.PerCall)
plOpts.PerCall = append(perCall, regPolicy)
}
Expand Down

0 comments on commit 29ba214

Please sign in to comment.