Skip to content

Commit

Permalink
Merge pull request #349 from AzureAD/release-0.7.0
Browse files Browse the repository at this point in the history
MSAL Go Release 0.7.0
  • Loading branch information
rayluo authored Sep 15, 2022
2 parents 8d382bd + 27390bb commit 031858c
Show file tree
Hide file tree
Showing 26 changed files with 236 additions and 86 deletions.
3 changes: 1 addition & 2 deletions .github/workflows/golangci-lint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ jobs:
uses: golangci/golangci-lint-action@v3
with:
# Required: the version of golangci-lint is required and must be specified without patch version: we always use the latest patch version.
version: v1.32
version: v1.48

# Optional: golangci-lint command line arguments.
# args: --issues-exit-code=0

36 changes: 35 additions & 1 deletion apps/confidential/confidential.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,8 @@ type Credential struct {
x5c []string

assertionCallback func(context.Context, AssertionRequestOptions) (string, error)

tokenProvider func(context.Context, TokenProviderParameters) (TokenProviderResult, error)
}

// toInternal returns the accesstokens.Credential that is used internally. The current structure of the
Expand All @@ -162,6 +164,9 @@ func (c Credential) toInternal() (*accesstokens.Credential, error) {
if c.assertionCallback != nil {
return &accesstokens.Credential{AssertionCallback: c.assertionCallback}, nil
}
if c.tokenProvider != nil {
return &accesstokens.Credential{TokenProvider: c.tokenProvider}, nil
}
return nil, errors.New("invalid credential")
}

Expand Down Expand Up @@ -226,6 +231,19 @@ func NewCredFromCertChain(certs []*x509.Certificate, key crypto.PrivateKey) (Cre
return cred, nil
}

// TokenProviderParameters is the authentication parameters passed to token providers
type TokenProviderParameters = exported.TokenProviderParameters

// TokenProviderResult is the authentication result returned by custom token providers
type TokenProviderResult = exported.TokenProviderResult

// NewCredFromTokenProvider creates a Credential from a function that provides access tokens. The function
// must be concurrency safe. This is intended only to allow the Azure SDK to cache MSI tokens. It isn't
// useful to applications in general because the token provider must implement all authentication logic.
func NewCredFromTokenProvider(provider func(context.Context, TokenProviderParameters) (TokenProviderResult, error)) Credential {
return Credential{tokenProvider: provider}
}

// AutoDetectRegion instructs MSAL Go to auto detect region for Azure regional token service.
func AutoDetectRegion() string {
return "TryAutoDetect"
Expand Down Expand Up @@ -348,7 +366,23 @@ func New(clientID string, cred Credential, options ...Option) (Client, error) {
return Client{}, err
}

base, err := base.New(clientID, opts.Authority, oauth.New(opts.HTTPClient), base.WithX5C(opts.SendX5C), base.WithCacheAccessor(opts.Accessor), base.WithRegionDetection(opts.AzureRegion))
baseOpts := []base.Option{
base.WithCacheAccessor(opts.Accessor),
base.WithRegionDetection(opts.AzureRegion),
base.WithX5C(opts.SendX5C),
}
if cred.tokenProvider != nil {
// The caller will handle all details of authentication, using Client only as a token cache.
// Declaring the authority host known prevents unnecessary metadata discovery requests. (The
// authority is irrelevant to Client and friends because the token provider is responsible
// for authentication.)
parsed, err := url.Parse(opts.Authority)
if err != nil {
return Client{}, errors.New("invalid authority")
}
baseOpts = append(baseOpts, base.WithKnownAuthorityHosts([]string{parsed.Hostname()}))
}
base, err := base.New(clientID, opts.Authority, oauth.New(opts.HTTPClient), baseOpts...)
if err != nil {
return Client{}, err
}
Expand Down
78 changes: 78 additions & 0 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,30 @@ import (
"errors"
"fmt"
"io"
"net/http"
"os"
"path/filepath"
"strings"
"testing"
"time"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/fake"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
"github.com/golang-jwt/jwt/v4"
)

// errorClient is an HTTP client for tests that should fail when confidential.Client sends a request
type errorClient struct{}

func (*errorClient) Do(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("expected no requests but received one for %s", req.URL.String())
}

func (*errorClient) CloseIdleConnections() {}

func TestCertFromPEM(t *testing.T) {
f, err := os.Open(filepath.Clean("../testdata/test-cert.pem"))
if err != nil {
Expand Down Expand Up @@ -411,3 +423,69 @@ func TestNewCredFromCertChainError(t *testing.T) {
})
}
}

func TestNewCredFromTokenProvider(t *testing.T) {
expectedToken := "expected token"
called := false
expiresIn := 4200
key := struct{}{}
ctx := context.WithValue(context.Background(), key, true)
cred := NewCredFromTokenProvider(func(c context.Context, tp exported.TokenProviderParameters) (exported.TokenProviderResult, error) {
if called {
t.Fatal("expected exactly one token provider invocation")
}
called = true
if v := c.Value(key); v == nil || !v.(bool) {
t.Fatal("callback received unexpected context")
}
if tp.CorrelationID == "" {
t.Fatal("expected CorrelationID")
}
if v := fmt.Sprint(tp.Scopes); v != fmt.Sprint(tokenScope) {
t.Fatalf(`unexpected scopes "%v"`, v)
}
return exported.TokenProviderResult{
AccessToken: expectedToken,
ExpiresInSeconds: expiresIn,
}, nil
})
client, err := New("client-id", cred, WithHTTPClient(&errorClient{}))
if err != nil {
t.Fatal(err)
}
ar, err := client.AcquireTokenByCredential(ctx, tokenScope)
if err != nil {
t.Fatal(err)
}
if !called {
t.Fatal("token provider wasn't invoked")
}
if v := int(time.Until(ar.ExpiresOn).Seconds()); v < expiresIn-2 || v > expiresIn {
t.Fatalf("expected ExpiresOn ~= %d seconds, got %d", expiresIn, v)
}
if ar.AccessToken != expectedToken {
t.Fatalf(`unexpected token "%s"`, ar.AccessToken)
}
ar, err = client.AcquireTokenSilent(context.Background(), tokenScope)
if err != nil {
t.Fatal(err)
}
if ar.AccessToken != expectedToken {
t.Fatalf(`unexpected token "%s"`, ar.AccessToken)
}
}

func TestNewCredFromTokenProviderError(t *testing.T) {
expectedError := "something went wrong"
cred := NewCredFromTokenProvider(func(ctx context.Context, tpp exported.TokenProviderParameters) (exported.TokenProviderResult, error) {
return exported.TokenProviderResult{}, errors.New(expectedError)
})
client, err := New("client-id", cred)
if err != nil {
t.Fatal(err)
}
_, err = client.AcquireTokenByCredential(context.Background(), tokenScope)
if err == nil || !strings.Contains(err.Error(), expectedError) {
t.Fatalf(`unexpected error "%v"`, err)
}
}
6 changes: 3 additions & 3 deletions apps/confidential/examples_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@ package confidential_test

import (
"fmt"
"io/ioutil"
"log"
"os"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
)

func ExampleNewCredFromCert_pem() {
b, err := ioutil.ReadFile("key.pem")
b, err := os.ReadFile("key.pem")
if err != nil {
log.Fatal(err)
}
Expand All @@ -35,7 +35,7 @@ func ExampleNewCredFromCert_pem() {
}

func ExampleNewCredFromCertChain() {
b, err := ioutil.ReadFile("key.pem")
b, err := os.ReadFile("key.pem")
if err != nil {
// TODO: handle error
}
Expand Down
3 changes: 1 addition & 2 deletions apps/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"reflect"
"strings"
Expand All @@ -21,7 +20,7 @@ var prettyConf = &pretty.Config{
TrackCycles: true,
Formatter: map[reflect.Type]interface{}{
reflect.TypeOf((*io.Reader)(nil)).Elem(): func(r io.Reader) string {
b, err := ioutil.ReadAll(r)
b, err := io.ReadAll(r)
if err != nil {
return "could not read io.Reader content"
}
Expand Down
13 changes: 11 additions & 2 deletions apps/internal/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,15 @@ func WithCacheAccessor(ca cache.ExportReplace) Option {
}
}

// WithKnownAuthorityHosts specifies hosts Client shouldn't validate or request metadata for because they're known to the user
func WithKnownAuthorityHosts(hosts []string) Option {
return func(c *Client) {
cp := make([]string, len(hosts))
copy(cp, hosts)
c.AuthParams.KnownAuthorityHosts = cp
}
}

// WithX5C specifies if x5c claim(public key of the certificate) should be sent to STS to enable Subject Name Issuer Authentication.
func WithX5C(sendX5C bool) Option {
return func(c *Client) {
Expand Down Expand Up @@ -230,7 +239,7 @@ func (b Client) AuthCodeURL(ctx context.Context, clientID, redirectURI string, s
func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilentParameters) (AuthResult, error) {
authParams := b.AuthParams // This is a copy, as we dont' have a pointer receiver and authParams is not a pointer.
authParams.Scopes = silent.Scopes
authParams.HomeaccountID = silent.Account.HomeAccountID
authParams.HomeAccountID = silent.Account.HomeAccountID
authParams.AuthorizationType = silent.AuthorizationType
authParams.UserAssertion = silent.UserAssertion

Expand Down Expand Up @@ -376,7 +385,7 @@ func (b Client) AllAccounts() []shared.Account {
func (b Client) Account(homeAccountID string) shared.Account {
authParams := b.AuthParams // This is a copy, as we dont' have a pointer receiver and .AuthParams is not a pointer.
authParams.AuthorizationType = authority.AccountByID
authParams.HomeaccountID = homeAccountID
authParams.HomeAccountID = homeAccountID
if s, ok := b.manager.(cache.Serializer); ok {
suggestedCacheKey := b.AuthParams.CacheKey(false)
b.cacheAccessor.Replace(s, suggestedCacheKey)
Expand Down
9 changes: 1 addition & 8 deletions apps/internal/base/internal/storage/items_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ package storage

import (
stdJSON "encoding/json"
"io/ioutil"
"os"
"testing"
"time"
Expand Down Expand Up @@ -204,13 +203,7 @@ func TestAppMetaDataMarshal(t *testing.T) {
}

func TestContractUnmarshalJSON(t *testing.T) {
jsonFile, err := os.Open(testFile)
if err != nil {
panic(err)
}
defer jsonFile.Close()

testCache, err := ioutil.ReadAll(jsonFile)
testCache, err := os.ReadFile(testFile)
if err != nil {
panic(err)
}
Expand Down
4 changes: 2 additions & 2 deletions apps/internal/base/internal/storage/partitioned_storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ func (m *PartitionedManager) Read(ctx context.Context, authParameters authority.

// Write writes a token response to the cache and returns the account information the token is stored with.
func (m *PartitionedManager) Write(authParameters authority.AuthParams, tokenResponse accesstokens.TokenResponse) (shared.Account, error) {
authParameters.HomeaccountID = tokenResponse.ClientInfo.HomeAccountID()
homeAccountID := authParameters.HomeaccountID
authParameters.HomeAccountID = tokenResponse.ClientInfo.HomeAccountID()
homeAccountID := authParameters.HomeAccountID
environment := authParameters.AuthorityInfo.Host
realm := authParameters.AuthorityInfo.Tenant
clientID := authParameters.ClientID
Expand Down
27 changes: 16 additions & 11 deletions apps/internal/base/internal/storage/storage.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,17 +84,22 @@ func isMatchingScopes(scopesOne []string, scopesTwo string) bool {

// Read reads a storage token from the cache if it exists.
func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams, account shared.Account) (TokenResponse, error) {
homeAccountID := authParameters.HomeaccountID
homeAccountID := authParameters.HomeAccountID
realm := authParameters.AuthorityInfo.Tenant
clientID := authParameters.ClientID
scopes := authParameters.Scopes

metadata, err := m.getMetadataEntry(ctx, authParameters.AuthorityInfo)
if err != nil {
return TokenResponse{}, err
// fetch metadata if and only if the authority isn't explicitly trusted
aliases := authParameters.KnownAuthorityHosts
if len(aliases) == 0 {
metadata, err := m.getMetadataEntry(ctx, authParameters.AuthorityInfo)
if err != nil {
return TokenResponse{}, err
}
aliases = metadata.Aliases
}

accessToken := m.readAccessToken(homeAccountID, metadata.Aliases, realm, clientID, scopes)
accessToken := m.readAccessToken(homeAccountID, aliases, realm, clientID, scopes)

if account.IsZero() {
return TokenResponse{
Expand All @@ -104,22 +109,22 @@ func (m *Manager) Read(ctx context.Context, authParameters authority.AuthParams,
Account: shared.Account{},
}, nil
}
idToken, err := m.readIDToken(homeAccountID, metadata.Aliases, realm, clientID)
idToken, err := m.readIDToken(homeAccountID, aliases, realm, clientID)
if err != nil {
return TokenResponse{}, err
}

AppMetaData, err := m.readAppMetaData(metadata.Aliases, clientID)
AppMetaData, err := m.readAppMetaData(aliases, clientID)
if err != nil {
return TokenResponse{}, err
}
familyID := AppMetaData.FamilyID

refreshToken, err := m.readRefreshToken(homeAccountID, metadata.Aliases, familyID, clientID)
refreshToken, err := m.readRefreshToken(homeAccountID, aliases, familyID, clientID)
if err != nil {
return TokenResponse{}, err
}
account, err = m.readAccount(homeAccountID, metadata.Aliases, realm)
account, err = m.readAccount(homeAccountID, aliases, realm)
if err != nil {
return TokenResponse{}, err
}
Expand All @@ -135,8 +140,8 @@ const scopeSeparator = " "

// Write writes a token response to the cache and returns the account information the token is stored with.
func (m *Manager) Write(authParameters authority.AuthParams, tokenResponse accesstokens.TokenResponse) (shared.Account, error) {
authParameters.HomeaccountID = tokenResponse.ClientInfo.HomeAccountID()
homeAccountID := authParameters.HomeaccountID
authParameters.HomeAccountID = tokenResponse.ClientInfo.HomeAccountID()
homeAccountID := authParameters.HomeAccountID
environment := authParameters.AuthorityInfo.Host
realm := authParameters.AuthorityInfo.Tenant
clientID := authParameters.ClientID
Expand Down
6 changes: 3 additions & 3 deletions apps/internal/base/internal/storage/storage_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package storage
import (
"context"
"errors"
"io/ioutil"
"os"
"reflect"
"sort"
"testing"
Expand Down Expand Up @@ -635,7 +635,7 @@ func TestStorageManagerSerialize(t *testing.T) {

func TestUnmarshal(t *testing.T) {
manager := newForTest(nil)
b, err := ioutil.ReadFile(testFile)
b, err := os.ReadFile(testFile)
if err != nil {
panic(err)
}
Expand Down Expand Up @@ -747,7 +747,7 @@ func TestRead(t *testing.T) {
Tenant: "realm",
}
authParameters := authority.AuthParams{
HomeaccountID: "hid",
HomeAccountID: "hid",
AuthorityInfo: authInfo,
ClientID: "cid",
Scopes: []string{"openid", "profile"},
Expand Down
Loading

0 comments on commit 031858c

Please sign in to comment.