Skip to content

Commit

Permalink
feat: ensure that ClientCert, ClientKey of custom_tls are optional an…
Browse files Browse the repository at this point in the history
…d handled correctly + default to system cert pool (#182)

Co-authored-by: Dave Heward <david.heward@unmind.com>
  • Loading branch information
davehewy and david-heward-unmind authored Nov 13, 2024
1 parent 71dfdd0 commit f126bc3
Showing 1 changed file with 35 additions and 23 deletions.
58 changes: 35 additions & 23 deletions mysql/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,11 +147,13 @@ func Provider() *schema.Provider {
},
"client_cert": {
Type: schema.TypeString,
Required: true,
Default: "",
Optional: true,
},
"client_key": {
Type: schema.TypeString,
Required: true,
Default: "",
Optional: true,
},
},
},
Expand Down Expand Up @@ -333,6 +335,7 @@ func providerConfigure(ctx context.Context, d *schema.ResourceData) (interface{}
customTLSMap := d.Get("custom_tls").([]interface{})
if len(customTLSMap) > 0 {
var customTLS CustomTLS
var rootCertPool *x509.CertPool
customMap := customTLSMap[0].(map[string]interface{})
customTLSJson, err := json.Marshal(customMap)
if err != nil {
Expand All @@ -345,36 +348,45 @@ func providerConfigure(ctx context.Context, d *schema.ResourceData) (interface{}
}

var pem []byte
rootCertPool := x509.NewCertPool()
if strings.HasPrefix(customTLS.CACert, "-----BEGIN") {
pem = []byte(customTLS.CACert)
if customTLS.CACert != "" {
rootCertPool := x509.NewCertPool()
if strings.HasPrefix(customTLS.CACert, "-----BEGIN") {
pem = []byte(customTLS.CACert)
} else {
pem, err = os.ReadFile(customTLS.CACert)
if err != nil {
return nil, diag.Errorf("failed to read CA cert: %v", err)
}
}
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
return nil, diag.Errorf("failed to append pem: %v", pem)
}
} else {
pem, err = os.ReadFile(customTLS.CACert)
// Use system cert pool as fallback
rootCertPool, err = x509.SystemCertPool()
if err != nil {
return nil, diag.Errorf("failed to read CA cert: %v", err)
return nil, diag.Errorf("failed to get system cert pool: %v", err)
}
}

if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
return nil, diag.Errorf("failed to append pem: %v", pem)
tlsConfigStruct = &tls.Config{
RootCAs: rootCertPool,
}

clientCert := make([]tls.Certificate, 0, 1)
var certs tls.Certificate
if strings.HasPrefix(customTLS.ClientCert, "-----BEGIN") {
certs, err = tls.X509KeyPair([]byte(customTLS.ClientCert), []byte(customTLS.ClientKey))
} else {
certs, err = tls.LoadX509KeyPair(customTLS.ClientCert, customTLS.ClientKey)
}
if err != nil {
return nil, diag.Errorf("error loading keypair: %v", err)
}
var cert tls.Certificate

clientCert = append(clientCert, certs)
tlsConfigStruct = &tls.Config{
RootCAs: rootCertPool,
Certificates: clientCert,
if customTLS.ClientCert != "" && customTLS.ClientKey != "" {
if strings.HasPrefix(customTLS.ClientCert, "-----BEGIN") {
cert, err = tls.X509KeyPair([]byte(customTLS.ClientCert), []byte(customTLS.ClientKey))
} else {
cert, err = tls.LoadX509KeyPair(customTLS.ClientCert, customTLS.ClientKey)
}
if err != nil {
return nil, diag.Errorf("error loading keypair: %v", err)
}
tlsConfigStruct.Certificates = []tls.Certificate{cert}
}

err = mysql.RegisterTLSConfig(customTLS.ConfigKey, tlsConfigStruct)
if err != nil {
return nil, diag.Errorf("failed registering TLS config: %v", err)
Expand Down

0 comments on commit f126bc3

Please sign in to comment.