Skip to content

Commit

Permalink
feat: use instance IP as SAN (#289)
Browse files Browse the repository at this point in the history
Co-authored-by: Eno Compton <enocom@google.com>
  • Loading branch information
cthumuluru and enocom authored May 1, 2023
1 parent d66557b commit 30d9740
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 95 deletions.
11 changes: 1 addition & 10 deletions dialer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,10 @@ func TestDialWithAdminAPIErrors(t *testing.T) {
}
}

func TestDialWithConfigurationErrors(t *testing.T) {
func TestDialWithUnavailableServerErrors(t *testing.T) {
ctx := context.Background()
inst := mock.NewFakeInstance(
"my-project", "my-region", "my-cluster", "my-instance",
mock.WithServerName("not-the-server-youre-looking-for"),
)
// Don't use the cleanup function. Because this test is about error
// cases, API requests (started in two separate goroutines) will
Expand All @@ -147,14 +146,6 @@ func TestDialWithConfigurationErrors(t *testing.T) {
if !errors.As(err, &wantErr2) {
t.Fatalf("when server proxy socket is unavailable, want = %T, got = %v", wantErr2, err)
}

stop := mock.StartServerProxy(t, inst)
defer stop()

_, err = d.Dial(ctx, "/projects/my-project/locations/my-region/clusters/my-cluster/instances/my-instance")
if !errors.As(err, &wantErr2) {
t.Fatalf("when TLS handshake fails, want = %T, got = %v", wantErr2, err)
}
}

func TestDialerWithCustomDialFunc(t *testing.T) {
Expand Down
143 changes: 58 additions & 85 deletions internal/alloydb/refresh.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (
"encoding/pem"
"errors"
"fmt"
"strings"
"time"

alloydbadmin "cloud.google.com/go/alloydb/apiv1beta"
Expand Down Expand Up @@ -78,7 +79,7 @@ func fetchEphemeralCert(
cl *alloydbadmin.AlloyDBAdminClient,
inst InstanceURI,
key *rsa.PrivateKey,
) (cc certChain, err error) {
) (cc *certs, err error) {
var end trace.EndSpanFunc
ctx, end = trace.StartSpan(ctx, "cloud.google.com/go/alloydbconn/internal.FetchEphemeralCert")
defer func() { end(err) }()
Expand All @@ -97,12 +98,12 @@ func fetchEphemeralCert(
}
csrBytes, err := x509.CreateCertificateRequest(rand.Reader, &tmpl, key)
if err != nil {
return certChain{}, err
return nil, err
}
buf := &bytes.Buffer{}
err = pem.Encode(buf, &pem.Block{Type: "CERTIFICATE REQUEST", Bytes: csrBytes})
if err != nil {
return certChain{}, err
return nil, err
}
req := &alloydbpb.GenerateClientCertificateRequest{
Parent: fmt.Sprintf(
Expand All @@ -113,102 +114,72 @@ func fetchEphemeralCert(
}
resp, err := cl.GenerateClientCertificate(ctx, req)
if err != nil {
return certChain{}, errtype.NewRefreshError(
return nil, errtype.NewRefreshError(
"create ephemeral cert failed",
inst.String(),
err,
)
}
// There should always be two certs in the chain. If this fails, the API has
// broken its contract with the client.
if len(resp.PemCertificateChain) != 2 {
return certChain{}, errtype.NewRefreshError(
"missing instance and root certificates",

certChainPEM := append([]string{resp.PemCertificate}, resp.PemCertificateChain...)
certPEMBlock := []byte(strings.Join(certChainPEM, "\n"))
keyPEMBlock := &pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
}

cert, err := tls.X509KeyPair(certPEMBlock, pem.EncodeToMemory(keyPEMBlock))
if err != nil {
return nil, errtype.NewRefreshError(
"create ephemeral cert failed",
inst.String(),
err,
)
}
rc, err := parseCert(resp.PemCertificateChain[1]) // root cert
if err != nil {
return certChain{}, errtype.NewRefreshError(
"failed to parse root cert",

// TODO(fixme) Take the root cert from the cert chain for now.
caCertPEMBlock, _ := pem.Decode([]byte(certChainPEM[2]))
if caCertPEMBlock == nil {
return nil, errtype.NewRefreshError(
"create ephemeral cert failed",
inst.String(),
err,
errors.New("no PEM data found in the ca cert"),
)
}
ic, err := parseCert(resp.PemCertificateChain[0]) // intermediate cert
caCert, err := x509.ParseCertificate(caCertPEMBlock.Bytes)
if err != nil {
return certChain{}, errtype.NewRefreshError(
"failed to parse intermediate cert",
return nil, errtype.NewRefreshError(
"create ephemeral cert failed",
inst.String(),
err,
)
}
c, err := parseCert(resp.PemCertificate) // client cert

// Extract expiry
clientCertPEMBlock, _ := pem.Decode([]byte(certChainPEM[0]))
if clientCertPEMBlock == nil {
return nil, errtype.NewRefreshError(
"create ephemeral cert failed",
inst.String(),
errors.New("no PEM data found in the client cert"),
)
}
clientCert, err := x509.ParseCertificate(clientCertPEMBlock.Bytes)
if err != nil {
return certChain{}, errtype.NewRefreshError(
"failed to parse client cert",
return nil, errtype.NewRefreshError(
"create ephemeral cert failed",
inst.String(),
err,
)
}

return certChain{
root: rc,
intermediate: ic,
client: c,
return &certs{
certChain: cert,
caCert: caCert,
expiry: clientCert.NotAfter,
}, nil
}

// createTLSConfig returns a *tls.Config for connecting securely to the AlloyDB
// instance.
func createTLSConfig(inst InstanceURI, cc certChain, info connectInfo, k *rsa.PrivateKey) *tls.Config {
certs := x509.NewCertPool()
certs.AddCert(cc.root)

return &tls.Config{
InsecureSkipVerify: true,
VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
var parsed []*x509.Certificate
for _, r := range rawCerts {
c, err := x509.ParseCertificate(r)
if err != nil {
return errtype.NewDialError("failed to parse X.509 certificate", inst.String(), err)
}
parsed = append(parsed, c)
}
server := parsed[0]
inter := x509.NewCertPool()
for i := 1; i < len(parsed); i++ {
inter.AddCert(parsed[i])
}

opts := x509.VerifyOptions{Roots: certs, Intermediates: inter}
if _, err := server.Verify(opts); err != nil {
return errtype.NewDialError("failed to verify certificate", inst.String(), err)
}

serverName := fmt.Sprintf("%v.server.alloydb", info.uid)
if server.Subject.CommonName != serverName {
return errtype.NewDialError(
fmt.Sprintf("certificate had CN %q, expected %q",
server.Subject.CommonName, serverName),
inst.String(),
nil,
)
}
return nil
},
Certificates: []tls.Certificate{{
Certificate: [][]byte{cc.client.Raw, cc.intermediate.Raw},
PrivateKey: k,
Leaf: cc.client,
}},
RootCAs: certs,
MinVersion: tls.VersionTLS13,
}
}

// newRefresher creates a Refresher.
func newRefresher(
client *alloydbadmin.AlloyDBAdminClient,
Expand Down Expand Up @@ -236,10 +207,10 @@ type refreshResult struct {
expiry time.Time
}

type certChain struct {
root *x509.Certificate
intermediate *x509.Certificate
client *x509.Certificate
type certs struct {
certChain tls.Certificate // TLS client certificate
caCert *x509.Certificate // CA certificate
expiry time.Time
}

func (r refresher) performRefresh(ctx context.Context, cn InstanceURI, k *rsa.PrivateKey) (res refreshResult, err error) {
Expand All @@ -264,7 +235,7 @@ func (r refresher) performRefresh(ctx context.Context, cn InstanceURI, k *rsa.Pr
}()

type certRes struct {
cc certChain
cc *certs
err error
}
certCh := make(chan certRes, 1)
Expand All @@ -285,7 +256,7 @@ func (r refresher) performRefresh(ctx context.Context, cn InstanceURI, k *rsa.Pr
return refreshResult{}, fmt.Errorf("refresh failed: %w", ctx.Err())
}

var cc certChain
var cc *certs
select {
case r := <-certCh:
if r.err != nil {
Expand All @@ -296,11 +267,13 @@ func (r refresher) performRefresh(ctx context.Context, cn InstanceURI, k *rsa.Pr
return refreshResult{}, fmt.Errorf("refresh failed: %w", ctx.Err())
}

c := createTLSConfig(cn, cc, info, k)
var expiry time.Time
// This should never not be the case, but we check to avoid a potential nil-pointer
if len(c.Certificates) > 0 {
expiry = c.Certificates[0].Leaf.NotAfter
caCerts := x509.NewCertPool()
caCerts.AddCert(cc.caCert)
c := &tls.Config{
Certificates: []tls.Certificate{cc.certChain},
RootCAs: caCerts,
ServerName: info.ipAddr,
}
return refreshResult{instanceIPAddr: info.ipAddr, conf: c, expiry: expiry}, nil

return refreshResult{instanceIPAddr: info.ipAddr, conf: c, expiry: cc.expiry}, nil
}
1 change: 1 addition & 0 deletions internal/mock/alloydb.go
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ func NewFakeInstance(proj, reg, clust, name string, opts ...Option) FakeAlloyDBI
IsCA: true,
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IPAddresses: []net.IP{net.IPv4(127, 0, 0, 1)},
}
signedServer, err := x509.CreateCertificate(
rand.Reader, serverTemplate, rootCert, &serverKey.PublicKey, rootCAKey)
Expand Down

0 comments on commit 30d9740

Please sign in to comment.