Skip to content

Commit

Permalink
Azidentity migration for service principal token (#287)
Browse files Browse the repository at this point in the history
  • Loading branch information
ekoehn authored Jun 5, 2023
1 parent e2bacce commit 3a26c19
Show file tree
Hide file tree
Showing 11 changed files with 936 additions and 26 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ require (
golang.org/x/time v0.0.0-20220609170525-579cf78fd858 // indirect
google.golang.org/appengine v1.6.7 // indirect
google.golang.org/protobuf v1.28.1 // indirect
gopkg.in/dnaeon/go-vcr.v3 v3.1.2 // indirect
gopkg.in/inf.v0 v0.9.1 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -325,6 +325,8 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8
gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
gopkg.in/dnaeon/go-vcr.v3 v3.1.2 h1:F1smfXBqQqwpVifDfUBQG6zzaGjzT+EnVZakrOdr5wA=
gopkg.in/dnaeon/go-vcr.v3 v3.1.2/go.mod h1:2IMOnnlx9I6u9x+YBsM3tAMx6AlOxnJ0pWxQAzZ79Ag=
gopkg.in/inf.v0 v0.9.1 h1:73M5CoZyi3ZLMOyDlQh031Cx6N9NDJ2Vvfl76EDAgDc=
gopkg.in/inf.v0 v0.9.1/go.mod h1:cWUDdTG/fYaXco+Dcufb5Vnc6Gp2YChqWtbxRZE0mXw=
gopkg.in/retry.v1 v1.0.3 h1:a9CArYczAVv6Qs6VGoLMio99GEs7kY9UzSF9+LD+iGs=
Expand Down
38 changes: 38 additions & 0 deletions pkg/token/README.MD
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
Instruction for recording response using [GO-VCR](https://github.com/dnaeon/go-vcr) for unit test
# Things to know if you want to record new recording
* All the recorded responses have been saved under folder `/testdata`

* Highly recommand using `RecordOnly` if you want completely new recording, otherwise, current recordings have been modified without the sensitive contents
* Here's the variable you need to input for recording
Modify these variables
modify authorizer clientID `AZURE_CLIENT_ID="<specify with real value>"`
modify authorizer clientSecret `AAD_SERVICE_PRINCIPAL_CLIENT_SECRET="<specify with real value>" `
modify authorizer clientCert `AZURE_CLIENT_CER="<specify with real value>"`
modify authorizer clientCertPass `AZURE_CLIENT_CERTIFICATE_PASSWORD="<specify with real value>" `
modify authorizer resourceID `AZURE_RESOURCE_ID="<specify with real value>"`
modify authorizer tenantID `AZURE_TENANT_ID="<specify with real value>" `
modify go-vcr record mode `VCR_MODE="<specify vcr mode>" `
you can set to record mode by setting vcr mode to RecordOnly `VCR_MODE="RecordOnly"`
To return to replay mode, simply unset the enviroment variable by `unset VCR_MODE`

Examples:
# Recording Mode
* Navigate to `pkg/token` folder in terminal
* Setup your enviroment variables

```
export AZURE_CLIENT_ID="<specify with real value>"
export AAD_SERVICE_PRINCIPAL_CLIENT_SECRET="<specify with real value>"
export AZURE_CLIENT_CER="<specify with real value>"
export AZURE_CLIENT_CERTIFICATE_PASSWORD="<specify with real value>"
export AZURE_RESOURCE_ID="<specify with real value>"
export AZURE_TENANT_ID="<specify with real value>"
export VCR_MODE="RecordOnly"
go test
```

# Replay Mode
```
unset VCR_MODE
go test
```
120 changes: 120 additions & 0 deletions pkg/token/govcrutils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package token

import (
"net/http"
"os"
"strings"

"gopkg.in/dnaeon/go-vcr.v3/cassette"
"gopkg.in/dnaeon/go-vcr.v3/recorder"
)

const (
tenantUUID = "AZURE_TENANT_ID"
vcrMode = "VCR_MODE"
vcrModeRecordOnly = "RecordOnly"
badSecret = "Bad_Secret"
redactionToken = "[REDACTED]"
testToken = "TEST_ACCESS_TOKEN"
)

// GetVCRHttpClient setup Go-vcr
func GetVCRHttpClient(path string, token string) (*recorder.Recorder, *http.Client) {
if len(path) == 0 || path == "" {
return nil, nil
}

opts := &recorder.Options{
CassetteName: path,
Mode: getVCRMode(),
}
rec, _ := recorder.NewWithOptions(opts)

hook := func(i *cassette.Interaction) error {
var detectedClientID, detectedClientSecret, detectedClientAssertion, detectedScope string
// Delete sensitive content
delete(i.Response.Headers, "Set-Cookie")
delete(i.Response.Headers, "X-Ms-Request-Id")
if i.Request.Form["client_id"] != nil {
detectedClientID = i.Request.Form["client_id"][0]
i.Request.Form["client_id"] = []string{redactionToken}
}
if i.Request.Form["client_secret"] != nil && i.Request.Form["client_secret"][0] != badSecret {
detectedClientSecret = i.Request.Form["client_secret"][0]
i.Request.Form["client_secret"] = []string{redactionToken}
}
if i.Request.Form["client_assertion"] != nil {
detectedClientAssertion = i.Request.Form["client_assertion"][0]
i.Request.Form["client_assertion"] = []string{redactionToken}
}
if i.Request.Form["scope"] != nil {
detectedScope = i.Request.Form["scope"][0][:strings.IndexByte(i.Request.Form["scope"][0], '/')]
i.Request.Form["scope"] = []string{redactionToken + "/.default openid offline_access profile"}
}
i.Request.URL = strings.ReplaceAll(i.Request.URL, os.Getenv(tenantUUID), tenantUUID)
i.Response.Body = strings.ReplaceAll(i.Response.Body, os.Getenv(tenantUUID), tenantUUID)

if detectedClientID != "" {
i.Request.Body = strings.ReplaceAll(i.Request.Body, detectedClientID, redactionToken)
}

if detectedClientSecret != "" {
i.Request.Body = strings.ReplaceAll(i.Request.Body, detectedClientSecret, redactionToken)
}

if detectedClientAssertion != "" {
i.Request.Body = strings.ReplaceAll(i.Request.Body, detectedClientAssertion, redactionToken)
}

if detectedScope != "" {
i.Request.Body = strings.ReplaceAll(i.Request.Body, detectedScope, redactionToken)
}

if strings.Contains(i.Response.Body, "access_token") {
i.Response.Body = `{"token_type":"Bearer","expires_in":86399,"ext_expires_in":86399,"access_token":"` + testToken + `"}`
}

if strings.Contains(i.Response.Body, "Invalid client secret provided") {
i.Response.Body = `{"error":"invalid_client","error_description":"AADSTS7000215: Invalid client secret provided. Ensure the secret being sent in the request is the client secret value, not the client secret ID, for a secret added to app ''[REDACTED]''.\r\nTrace ID: [REDACTED]\r\nCorrelation ID: [REDACTED]\r\nTimestamp: 2023-06-02 21:00:26Z","error_codes":[7000215],"timestamp":"2023-06-02 21:00:26Z","trace_id":"[REDACTED]","correlation_id":"[REDACTED]","error_uri":"https://login.microsoftonline.com/error?code=7000215"}`
}
return nil
}
rec.AddHook(hook, recorder.BeforeSaveHook)

playbackHook := func(i *cassette.Interaction) error {
// Return a verifiable unique token on each test
if strings.Contains(i.Response.Body, "access_token") {
i.Response.Body = strings.ReplaceAll(i.Response.Body, testToken, token)
}
return nil
}
rec.AddHook(playbackHook, recorder.BeforeResponseReplayHook)

rec.SetMatcher(customMatcher)
rec.SetReplayableInteractions(true)

return rec, rec.GetDefaultClient()
}

func customMatcher(r *http.Request, i cassette.Request) bool {
id := os.Getenv(tenantUUID)
if id == "" {
id = "00000000-0000-0000-0000-000000000000"
}
switch os.Getenv(vcrMode) {
case vcrModeRecordOnly:
default:
r.URL.Path = strings.ReplaceAll(r.URL.Path, id, tenantUUID)
}
return cassette.DefaultMatcher(r, i)
}

// Get go-vcr record mode from environment variable
func getVCRMode() recorder.Mode {
switch os.Getenv(vcrMode) {
case vcrModeRecordOnly:
return recorder.ModeRecordOnly
default:
return recorder.ModeReplayOnly
}
}
15 changes: 14 additions & 1 deletion pkg/token/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"

"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/go-autorest/autorest/adal"
"github.com/Azure/go-autorest/autorest/azure"
)
Expand All @@ -19,13 +20,17 @@ func newTokenProvider(o *Options) (TokenProvider, error) {
if err != nil {
return nil, fmt.Errorf("failed to get oAuthConfig. isLegacy: %t, err: %s", o.IsLegacy, err)
}
cloudConfiguration, err := getCloudConfig(o.Environment)
if err != nil {
return nil, fmt.Errorf("failed to get cloud.Configuration. err: %s", err)
}
switch o.LoginMethod {
case DeviceCodeLogin:
return newDeviceCodeTokenProvider(*oAuthConfig, o.ClientID, o.ServerID, o.TenantID)
case InteractiveLogin:
return newInteractiveTokenProvider(*oAuthConfig, o.ClientID, o.ServerID, o.TenantID)
case ServicePrincipalLogin:
return newServicePrincipalToken(*oAuthConfig, o.ClientID, o.ClientSecret, o.ClientCert, o.ClientCertPassword, o.ServerID, o.TenantID)
return newServicePrincipalToken(cloudConfiguration, o.ClientID, o.ClientSecret, o.ClientCert, o.ClientCertPassword, o.ServerID, o.TenantID)
case ROPCLogin:
return newResourceOwnerToken(*oAuthConfig, o.ClientID, o.Username, o.Password, o.ServerID, o.TenantID)
case MSILogin:
Expand All @@ -39,6 +44,14 @@ func newTokenProvider(o *Options) (TokenProvider, error) {
return nil, errors.New("unsupported token provider")
}

func getCloudConfig(envName string) (cloud.Configuration, error) {
env, err := getAzureEnvironment(envName)
c := cloud.Configuration{
ActiveDirectoryAuthorityHost: env.ActiveDirectoryEndpoint,
}
return c, err
}

func getOAuthConfig(envName, tenantID string, isLegacy bool) (*adal.OAuthConfig, error) {
var (
oAuthConfig *adal.OAuthConfig
Expand Down
97 changes: 72 additions & 25 deletions pkg/token/serviceprincipaltoken.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,20 @@
package token

import (
"context"
"crypto/rsa"
"crypto/x509"
"encoding/json"
"encoding/pem"
"errors"
"fmt"
"os"
"strconv"

"github.com/Azure/azure-sdk-for-go/sdk/azcore"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azidentity"
"github.com/Azure/go-autorest/autorest/adal"
"golang.org/x/crypto/pkcs12"
)
Expand All @@ -24,10 +31,10 @@ type servicePrincipalToken struct {
clientCertPassword string
resourceID string
tenantID string
oAuthConfig adal.OAuthConfig
cloud cloud.Configuration
}

func newServicePrincipalToken(oAuthConfig adal.OAuthConfig, clientID, clientSecret, clientCert, clientCertPassword, resourceID, tenantID string) (TokenProvider, error) {
func newServicePrincipalToken(cloud cloud.Configuration, clientID, clientSecret, clientCert, clientCertPassword, resourceID, tenantID string) (TokenProvider, error) {
if clientID == "" {
return nil, errors.New("clientID cannot be empty")
}
Expand All @@ -51,32 +58,55 @@ func newServicePrincipalToken(oAuthConfig adal.OAuthConfig, clientID, clientSecr
clientCertPassword: clientCertPassword,
resourceID: resourceID,
tenantID: tenantID,
oAuthConfig: oAuthConfig,
cloud: cloud,
}, nil
}

// Token fetches an azcore.AccessToken from the Azure SDK and converts it to an adal.Token for use with kubelogin.
func (p *servicePrincipalToken) Token() (adal.Token, error) {
emptyToken := adal.Token{}
callback := func(t adal.Token) error {
return nil
}
return p.TokenWithOptions(nil)
}

var (
spt *adal.ServicePrincipalToken
err error
)
func (p *servicePrincipalToken) TokenWithOptions(options *azcore.ClientOptions) (adal.Token, error) {
emptyToken := adal.Token{}
var spnAccessToken azcore.AccessToken

// Request a new Azure token provider for service principal
if p.clientSecret != "" {
spt, err = adal.NewServicePrincipalToken(
p.oAuthConfig,
clientOptions := &azidentity.ClientSecretCredentialOptions{
ClientOptions: azcore.ClientOptions{
Cloud: p.cloud,
},
}
if options != nil {
clientOptions.ClientOptions = *options
}
cred, err := azidentity.NewClientSecretCredential(
p.tenantID,
p.clientID,
p.clientSecret,
p.resourceID,
callback)
clientOptions,
)
if err != nil {
return emptyToken, fmt.Errorf("unable to create credential. Received: %w", err)
}

// Use the token provider to get a new token
spnAccessToken, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{p.resourceID + "/.default"}})
if err != nil {
return emptyToken, fmt.Errorf("failed to create service principal token using secret: %s", err)
return emptyToken, fmt.Errorf("failed to create service principal token using secret: %w", err)
}

} else if p.clientCert != "" {
clientOptions := &azidentity.ClientCertificateCredentialOptions{
ClientOptions: azcore.ClientOptions{
Cloud: p.cloud,
},
SendCertificateChain: true,
}
if options != nil {
clientOptions.ClientOptions = *options
}
certData, err := os.ReadFile(p.clientCert)
if err != nil {
return emptyToken, fmt.Errorf("failed to read the certificate file (%s): %w", p.clientCert, err)
Expand All @@ -88,23 +118,40 @@ func (p *servicePrincipalToken) Token() (adal.Token, error) {
return emptyToken, fmt.Errorf("failed to decode pkcs12 certificate while creating spt: %w", err)
}

spt, err = adal.NewServicePrincipalTokenFromCertificate(
p.oAuthConfig,
cred, err := azidentity.NewClientCertificateCredential(
p.tenantID,
p.clientID,
cert,
[]*x509.Certificate{cert},
rsaPrivateKey,
p.resourceID,
callback)
clientOptions,
)
if err != nil {
return emptyToken, fmt.Errorf("unable to create credential. Received: %v", err)
}
spnAccessToken, err = cred.GetToken(context.Background(), policy.TokenRequestOptions{Scopes: []string{p.resourceID + "/.default"}})
if err != nil {
return emptyToken, fmt.Errorf("failed to create service principal token using cert: %s", err)
}

} else {
return emptyToken, errors.New("service principal token requires either client secret or certificate")
}

err = spt.Refresh()
if err != nil {
return emptyToken, err
if spnAccessToken.Token == "" {
return emptyToken, errors.New("unexpectedly got empty access token")
}
return spt.Token(), nil

// azurecore.AccessTokens have ExpiresOn as Time.Time. We need to convert it to JSON.Number
// by fetching the time in seconds since the Unix epoch via Unix() and then converting to a
// JSON.Number via formatting as a string using a base-10 int64 conversion.
expiresOn := json.Number(strconv.FormatInt(spnAccessToken.ExpiresOn.Unix(), 10))

// Re-wrap the azurecore.AccessToken into an adal.Token
return adal.Token{
AccessToken: spnAccessToken.Token,
ExpiresOn: expiresOn,
Resource: p.resourceID,
}, nil
}

func isPublicKeyEqual(key1, key2 *rsa.PublicKey) bool {
Expand Down
Loading

0 comments on commit 3a26c19

Please sign in to comment.