Skip to content

Commit

Permalink
feat(auth): auth library can talk to S2A over mTLS (#10634)
Browse files Browse the repository at this point in the history
  • Loading branch information
xmenxk authored Aug 15, 2024
1 parent b02a9a1 commit 5250a13
Show file tree
Hide file tree
Showing 6 changed files with 296 additions and 156 deletions.
82 changes: 70 additions & 12 deletions auth/internal/transport/cba.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ package transport
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"log"
"net"
"net/http"
"net/url"
Expand All @@ -44,10 +46,12 @@ const (
googleAPIUseMTLSOld = "GOOGLE_API_USE_MTLS"

universeDomainPlaceholder = "UNIVERSE_DOMAIN"

mtlsMDSRoot = "/run/google-mds-mtls/root.crt"
mtlsMDSKey = "/run/google-mds-mtls/client.key"
)

var (
mdsMTLSAutoConfigSource mtlsConfigSource
errUniverseNotSupportedMTLS = errors.New("mTLS is not supported in any universe other than googleapis.com")
)

Expand Down Expand Up @@ -120,7 +124,20 @@ func GetGRPCTransportCredsAndEndpoint(opts *Options) (credentials.TransportCrede
defaultTransportCreds := credentials.NewTLS(&tls.Config{
GetClientCertificate: config.clientCertSource,
})
if config.s2aAddress == "" {

var s2aAddr string
var transportCredsForS2A credentials.TransportCredentials

if config.mtlsS2AAddress != "" {
s2aAddr = config.mtlsS2AAddress
transportCredsForS2A, err = loadMTLSMDSTransportCreds(mtlsMDSRoot, mtlsMDSKey)
if err != nil {
log.Printf("Loading MTLS MDS credentials failed: %v", err)
return defaultTransportCreds, config.endpoint, nil
}
} else if config.s2aAddress != "" {
s2aAddr = config.s2aAddress
} else {
return defaultTransportCreds, config.endpoint, nil
}

Expand All @@ -133,8 +150,9 @@ func GetGRPCTransportCredsAndEndpoint(opts *Options) (credentials.TransportCrede
}

s2aTransportCreds, err := s2a.NewClientCreds(&s2a.ClientOptions{
S2AAddress: config.s2aAddress,
FallbackOpts: fallbackOpts,
S2AAddress: s2aAddr,
TransportCreds: transportCredsForS2A,
FallbackOpts: fallbackOpts,
})
if err != nil {
// Use default if we cannot initialize S2A client transport credentials.
Expand All @@ -151,7 +169,19 @@ func GetHTTPTransportConfig(opts *Options) (cert.Provider, func(context.Context,
return nil, nil, err
}

if config.s2aAddress == "" {
var s2aAddr string
var transportCredsForS2A credentials.TransportCredentials

if config.mtlsS2AAddress != "" {
s2aAddr = config.mtlsS2AAddress
transportCredsForS2A, err = loadMTLSMDSTransportCreds(mtlsMDSRoot, mtlsMDSKey)
if err != nil {
log.Printf("Loading MTLS MDS credentials failed: %v", err)
return config.clientCertSource, nil, nil
}
} else if config.s2aAddress != "" {
s2aAddr = config.s2aAddress
} else {
return config.clientCertSource, nil, nil
}

Expand All @@ -169,12 +199,38 @@ func GetHTTPTransportConfig(opts *Options) (cert.Provider, func(context.Context,
}

dialTLSContextFunc := s2a.NewS2ADialTLSContextFunc(&s2a.ClientOptions{
S2AAddress: config.s2aAddress,
FallbackOpts: fallbackOpts,
S2AAddress: s2aAddr,
TransportCreds: transportCredsForS2A,
FallbackOpts: fallbackOpts,
})
return nil, dialTLSContextFunc, nil
}

func loadMTLSMDSTransportCreds(mtlsMDSRootFile, mtlsMDSKeyFile string) (credentials.TransportCredentials, error) {
rootPEM, err := os.ReadFile(mtlsMDSRootFile)
if err != nil {
return nil, err
}
caCertPool := x509.NewCertPool()
ok := caCertPool.AppendCertsFromPEM(rootPEM)
if !ok {
return nil, errors.New("failed to load MTLS MDS root certificate")
}
// The mTLS MDS credentials are formatted as the concatenation of a PEM-encoded certificate chain
// followed by a PEM-encoded private key. For this reason, the concatenation is passed in to the
// tls.X509KeyPair function as both the certificate chain and private key arguments.
cert, err := tls.LoadX509KeyPair(mtlsMDSKeyFile, mtlsMDSKeyFile)
if err != nil {
return nil, err
}
tlsConfig := tls.Config{
RootCAs: caCertPool,
Certificates: []tls.Certificate{cert},
MinVersion: tls.VersionTLS13,
}
return credentials.NewTLS(&tlsConfig), nil
}

func getTransportConfig(opts *Options) (*transportConfig, error) {
clientCertSource, err := GetClientCertificateProvider(opts)
if err != nil {
Expand All @@ -196,17 +252,17 @@ func getTransportConfig(opts *Options) (*transportConfig, error) {
return nil, errUniverseNotSupportedMTLS
}

s2aMTLSEndpoint := opts.DefaultMTLSEndpoint

s2aAddress := GetS2AAddress()
if s2aAddress == "" {
mtlsS2AAddress := GetMTLSS2AAddress()
if s2aAddress == "" && mtlsS2AAddress == "" {
return &defaultTransportConfig, nil
}
return &transportConfig{
clientCertSource: clientCertSource,
endpoint: endpoint,
s2aAddress: s2aAddress,
s2aMTLSEndpoint: s2aMTLSEndpoint,
mtlsS2AAddress: mtlsS2AAddress,
s2aMTLSEndpoint: opts.DefaultMTLSEndpoint,
}, nil
}

Expand Down Expand Up @@ -241,8 +297,10 @@ type transportConfig struct {
clientCertSource cert.Provider
// The corresponding endpoint to use based on client certificate source.
endpoint string
// The S2A address if it can be used, otherwise an empty string.
// The plaintext S2A address if it can be used, otherwise an empty string.
s2aAddress string
// The MTLS S2A address if it can be used, otherwise an empty string.
mtlsS2AAddress string
// The MTLS endpoint to use with S2A.
s2aMTLSEndpoint string
}
Expand Down
108 changes: 88 additions & 20 deletions auth/internal/transport/cba_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ import (
"fmt"
"net/http"
"testing"
"time"

"cloud.google.com/go/auth/internal"
"cloud.google.com/go/auth/internal/transport/cert"
Expand Down Expand Up @@ -50,6 +49,20 @@ var (
return string(configStr), nil
}

validConfigRespMTLSS2A = func() (string, error) {
validConfig := mtlsConfig{
S2A: &s2aAddresses{
PlaintextAddress: "",
MTLSAddress: testMTLSS2AAddr,
},
}
configStr, err := json.Marshal(validConfig)
if err != nil {
return "", err
}
return string(configStr), nil
}

errorConfigResp = func() (string, error) {
return "", fmt.Errorf("error getting config")
}
Expand Down Expand Up @@ -250,7 +263,7 @@ func TestGetEndpointWithClientCertSource(t *testing.T) {
}
}

func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) {
func TestGetGRPCTransportConfigAndEndpoint_S2A(t *testing.T) {
testCases := []struct {
name string
opts *Options
Expand Down Expand Up @@ -324,11 +337,21 @@ func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) {
validConfigResp,
testRegularEndpoint,
},
{
"no client cert, MTLS S2A address not empty, no MTLS MDS cert",
&Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
},
validConfigRespMTLSS2A,
testRegularEndpoint,
},
}
defer setupTest(t)()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
httpGetMetadataMTLSConfig = tc.s2ARespFn
mtlsConfiguration, _ = queryConfig()
if tc.opts.ClientCertProvider != nil {
t.Setenv(googleAPIUseCertSource, "true")
} else {
Expand All @@ -338,17 +361,15 @@ func TestGetGRPCTransportConfigAndEndpoint(t *testing.T) {
if tc.want != endpoint {
t.Fatalf("want endpoint: %s, got %s", tc.want, endpoint)
}
// Let the cached MTLS config expire at the end of each test case.
time.Sleep(2 * time.Millisecond)
})
}
}

func TestGetHTTPTransportConfig_S2a(t *testing.T) {
func TestGetHTTPTransportConfig_S2A(t *testing.T) {
testCases := []struct {
name string
opts *Options
s2aFn func() (string, error)
s2ARespFn func() (string, error)
want string
isDialFnNil bool
}{
Expand All @@ -359,7 +380,7 @@ func TestGetHTTPTransportConfig_S2a(t *testing.T) {
DefaultEndpointTemplate: testEndpointTemplate,
ClientCertProvider: fakeClientCertSource,
},
s2aFn: validConfigResp,
s2ARespFn: validConfigResp,
want: testMTLSEndpoint,
isDialFnNil: true,
},
Expand All @@ -369,16 +390,16 @@ func TestGetHTTPTransportConfig_S2a(t *testing.T) {
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
},
s2aFn: validConfigResp,
want: testMTLSEndpoint,
s2ARespFn: validConfigResp,
want: testMTLSEndpoint,
},
{
name: "no client cert, S2A address empty",
opts: &Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
},
s2aFn: invalidConfigResp,
s2ARespFn: invalidConfigResp,
want: testRegularEndpoint,
isDialFnNil: true,
},
Expand All @@ -389,7 +410,7 @@ func TestGetHTTPTransportConfig_S2a(t *testing.T) {
DefaultEndpointTemplate: testEndpointTemplate,
Endpoint: testOverrideEndpoint,
},
s2aFn: validConfigResp,
s2ARespFn: validConfigResp,
want: testOverrideEndpoint,
isDialFnNil: true,
},
Expand All @@ -399,7 +420,7 @@ func TestGetHTTPTransportConfig_S2a(t *testing.T) {
DefaultMTLSEndpoint: "",
DefaultEndpointTemplate: testEndpointTemplate,
},
s2aFn: validConfigResp,
s2ARespFn: validConfigResp,
want: testRegularEndpoint,
isDialFnNil: true,
},
Expand All @@ -410,15 +431,26 @@ func TestGetHTTPTransportConfig_S2a(t *testing.T) {
DefaultEndpointTemplate: testEndpointTemplate,
Client: http.DefaultClient,
},
s2aFn: validConfigResp,
s2ARespFn: validConfigResp,
want: testRegularEndpoint,
isDialFnNil: true,
},
{
name: "no client cert, MTLS S2A address not empty, no MTLS MDS cert",
opts: &Options{
DefaultMTLSEndpoint: testMTLSEndpoint,
DefaultEndpointTemplate: testEndpointTemplate,
},
s2ARespFn: validConfigRespMTLSS2A,
want: testRegularEndpoint,
isDialFnNil: true,
},
}
defer setupTest(t)()
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
httpGetMetadataMTLSConfig = tc.s2aFn
httpGetMetadataMTLSConfig = tc.s2ARespFn
mtlsConfiguration, _ = queryConfig()
if tc.opts.ClientCertProvider != nil {
t.Setenv(googleAPIUseCertSource, "true")
} else {
Expand All @@ -431,22 +463,58 @@ func TestGetHTTPTransportConfig_S2a(t *testing.T) {
if want, got := tc.isDialFnNil, dialFunc == nil; want != got {
t.Errorf("expecting returned dialFunc is nil: [%v], got [%v]", tc.isDialFnNil, got)
}
// Let MTLS config expire at end of each test case.
time.Sleep(2 * time.Millisecond)
})
}
}

func TestLoadMTLSMDSTransportCreds(t *testing.T) {
testCases := []struct {
name string
rootFile string
keyFile string
wantErr bool
}{
{
name: "missing root file",
rootFile: "",
keyFile: "./testdata/mtls_mds_key.pem",
wantErr: true,
},
{
name: "missing key file",
rootFile: "./testdata/mtls_mds_root.pem",
keyFile: "",
wantErr: true,
},
{
name: "missing both root and key files",
rootFile: "",
keyFile: "",
wantErr: true,
},
{
name: "load credentials success",
rootFile: "./testdata/mtls_mds_root.pem",
keyFile: "./testdata/mtls_mds_key.pem",
wantErr: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
_, err := loadMTLSMDSTransportCreds(tc.rootFile, tc.keyFile)
if gotErr := err != nil; gotErr != tc.wantErr {
t.Errorf("loadMTLSMDSTransportCreds(%q, %q) got error: %v, want error: %v", tc.rootFile, tc.keyFile, gotErr, tc.wantErr)
}
})
}
}

func setupTest(t *testing.T) func() {
oldHTTPGet := httpGetMetadataMTLSConfig
oldExpiry := configExpiry

configExpiry = time.Millisecond
t.Setenv(googleAPIUseS2AEnv, "true")

return func() {
httpGetMetadataMTLSConfig = oldHTTPGet
configExpiry = oldExpiry
}
}

Expand Down
Loading

0 comments on commit 5250a13

Please sign in to comment.