Skip to content

Commit

Permalink
feat: support env ALIBABA_CLOUD_STS_REGION for sts endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
yndu13 authored and peze committed Oct 28, 2024
1 parent a235167 commit 47c2eab
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 6 deletions.
2 changes: 2 additions & 0 deletions credentials/internal/providers/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ func (b *OIDCCredentialsProviderBuilder) Build() (provider *OIDCCredentialsProvi
if b.provider.stsEndpoint == "" {
if b.provider.stsRegionId != "" {
b.provider.stsEndpoint = fmt.Sprintf("sts.%s.aliyuncs.com", b.provider.stsRegionId)
} else if region := os.Getenv("ALIBABA_CLOUD_STS_REGION"); region != "" {
b.provider.stsEndpoint = fmt.Sprintf("sts.%s.aliyuncs.com", region)
} else {
b.provider.stsEndpoint = "sts.aliyuncs.com"
}
Expand Down
10 changes: 9 additions & 1 deletion credentials/internal/providers/oidc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ func TestOIDCCredentialsProviderGetCredentialsWithError(t *testing.T) {
}

func TestNewOIDCCredentialsProvider(t *testing.T) {
rollback := utils.Memory("ALIBABA_CLOUD_OIDC_TOKEN_FILE", "ALIBABA_CLOUD_OIDC_PROVIDER_ARN", "ALIBABA_CLOUD_ROLE_ARN")
rollback := utils.Memory("ALIBABA_CLOUD_OIDC_TOKEN_FILE", "ALIBABA_CLOUD_OIDC_PROVIDER_ARN", "ALIBABA_CLOUD_ROLE_ARN", "ALIBABA_CLOUD_STS_REGION")
defer func() {
rollback()
}()
Expand Down Expand Up @@ -89,6 +89,14 @@ func TestNewOIDCCredentialsProvider(t *testing.T) {
assert.Equal(t, "role_arn_from_env", p.roleArn)
// sts endpoint: default
assert.Equal(t, "sts.aliyuncs.com", p.stsEndpoint)

// sts endpoint: with sts endpoint env
os.Setenv("ALIBABA_CLOUD_STS_REGION", "cn-hangzhou")
p, err = NewOIDCCredentialsProviderBuilder().
Build()
assert.Nil(t, err)
assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", p.stsEndpoint)

// sts endpoint: with sts endpoint
p, err = NewOIDCCredentialsProviderBuilder().
WithSTSEndpoint("sts.cn-shanghai.aliyuncs.com").
Expand Down
3 changes: 3 additions & 0 deletions credentials/internal/providers/ram_role_arn.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"net/http"
"net/url"
"os"
"strconv"
"strings"
"time"
Expand Down Expand Up @@ -146,6 +147,8 @@ func (builder *RAMRoleARNCredentialsProviderBuilder) Build() (provider *RAMRoleA
if builder.provider.stsEndpoint == "" {
if builder.provider.stsRegionId != "" {
builder.provider.stsEndpoint = fmt.Sprintf("sts.%s.aliyuncs.com", builder.provider.stsRegionId)
} else if region := os.Getenv("ALIBABA_CLOUD_STS_REGION"); region != "" {
builder.provider.stsEndpoint = fmt.Sprintf("sts.%s.aliyuncs.com", region)
} else {
builder.provider.stsEndpoint = "sts.aliyuncs.com"
}
Expand Down
29 changes: 24 additions & 5 deletions credentials/internal/providers/ram_role_arn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@ package providers

import (
"errors"
"os"
"strings"
"testing"
"time"

httputil "github.com/aliyun/credentials-go/credentials/internal/http"
"github.com/aliyun/credentials-go/credentials/internal/utils"
"github.com/stretchr/testify/assert"
)

func TestNewRAMRoleARNCredentialsProvider(t *testing.T) {
rollback := utils.Memory("ALIBABA_CLOUD_STS_REGION")
defer func() {
rollback()
}()
// case 1: no credentials provider
_, err := NewRAMRoleARNCredentialsProviderBuilder().
Build()
Expand Down Expand Up @@ -70,11 +76,10 @@ func TestNewRAMRoleARNCredentialsProvider(t *testing.T) {
// sts endpoint with sts region
assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", p.stsEndpoint)

// sts endpoint with sts endpoint
// default sts endpoint
p, err = NewRAMRoleARNCredentialsProviderBuilder().
WithCredentialsProvider(akProvider).
WithRoleArn("roleArn").
WithStsEndpoint("sts.cn-shanghai.aliyuncs.com").
WithPolicy("policy").
WithExternalId("externalId").
WithRoleSessionName("rsn").
Expand All @@ -87,9 +92,10 @@ func TestNewRAMRoleARNCredentialsProvider(t *testing.T) {
assert.Equal(t, "externalId", p.externalId)
assert.Equal(t, "", p.stsRegionId)
assert.Equal(t, 1000, p.durationSeconds)
assert.Equal(t, "sts.cn-shanghai.aliyuncs.com", p.stsEndpoint)
assert.Equal(t, "sts.aliyuncs.com", p.stsEndpoint)

// default sts endpoint
// sts endpoint with env
os.Setenv("ALIBABA_CLOUD_STS_REGION", "cn-hangzhou")
p, err = NewRAMRoleARNCredentialsProviderBuilder().
WithCredentialsProvider(akProvider).
WithRoleArn("roleArn").
Expand All @@ -99,13 +105,26 @@ func TestNewRAMRoleARNCredentialsProvider(t *testing.T) {
WithDurationSeconds(1000).
Build()
assert.Nil(t, err)
assert.Equal(t, "sts.cn-hangzhou.aliyuncs.com", p.stsEndpoint)

// sts endpoint with sts endpoint
p, err = NewRAMRoleARNCredentialsProviderBuilder().
WithCredentialsProvider(akProvider).
WithRoleArn("roleArn").
WithStsEndpoint("sts.cn-shanghai.aliyuncs.com").
WithPolicy("policy").
WithExternalId("externalId").
WithRoleSessionName("rsn").
WithDurationSeconds(1000).
Build()
assert.Nil(t, err)
assert.Equal(t, "rsn", p.roleSessionName)
assert.Equal(t, "roleArn", p.roleArn)
assert.Equal(t, "policy", p.policy)
assert.Equal(t, "externalId", p.externalId)
assert.Equal(t, "", p.stsRegionId)
assert.Equal(t, 1000, p.durationSeconds)
assert.Equal(t, "sts.aliyuncs.com", p.stsEndpoint)
assert.Equal(t, "sts.cn-shanghai.aliyuncs.com", p.stsEndpoint)
}

func TestRAMRoleARNCredentialsProvider_getCredentials(t *testing.T) {
Expand Down

0 comments on commit 47c2eab

Please sign in to comment.