Skip to content
This repository has been archived by the owner on Oct 23, 2024. It is now read-only.

Commit

Permalink
Merge pull request #111 from planetscale/dbussink/improve-sql-proxy-cmd
Browse files Browse the repository at this point in the history
Improve the sql-proxy command line tool
  • Loading branch information
dbussink authored Aug 19, 2021
2 parents 918b93f + 256dd5c commit 9109ec0
Showing 1 changed file with 28 additions and 47 deletions.
75 changes: 28 additions & 47 deletions cmd/sql-proxy-client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,13 @@ import (
"crypto/rand"
"crypto/tls"
"crypto/x509"
"encoding/pem"
"errors"
"flag"
"fmt"
"io/ioutil"
"net"
"os"
"os/signal"
"strconv"
"strings"

ps "github.com/planetscale/planetscale-go/planetscale"
Expand All @@ -38,7 +37,7 @@ func realMain() error {
host := flag.String("host", "127.0.0.1", "Local host to bind and listen for connections")
port := flag.String("port", "3306", "Local port to bind and listen for connections")

remoteAddr := flag.String("remote-addr", "", "MySQL remote network address")
remoteHost := flag.String("remote-host", "", "MySQL remote host")
remotePort := flag.Int("remote-port", 3307, "MySQL remote port")

orgName := flag.String("org", os.Getenv("PLANETSCALE_ORG"),
Expand All @@ -54,7 +53,6 @@ func realMain() error {

showVersion := flag.Bool("version", false, "Show version of the proxy")

caPath := flag.String("ca", "", "MySQL CA Cert path")
clientCertPath := flag.String("cert", "", "MySQL Client Cert path")
clientKeyPath := flag.String("key", "", "MySQL Client Key path")

Expand All @@ -69,30 +67,44 @@ func realMain() error {
return errors.New("--token and --service-token/--service-token-name cannot be set at the same time")
}

if *orgName == "" || *dbName == "" || *branchName == "" {
return errors.New("--org, --database or --branch is not set")
}

var certSource proxy.CertSource
var err error
var instance string

certSource, err = newRemoteCertSource(*token, *serviceToken, *serviceTokenName)
if err != nil {
return err
if *token != "" || (*serviceToken != "" && *serviceTokenName != "") {
if *orgName == "" || *dbName == "" || *branchName == "" {
return errors.New("--org, --database or --branch is not set with a token")
}
instance = fmt.Sprintf("%s/%s/%s", *orgName, *dbName, *branchName)

certSource, err = newRemoteCertSource(*token, *serviceToken, *serviceTokenName)
if err != nil {
return err
}
}

if *caPath != "" && *clientCertPath != "" && *clientKeyPath != "" {
certSource, err = newLocalCertSource(*caPath, *clientCertPath, *clientKeyPath, *remoteAddr, *remotePort)
if *remoteHost != "" && *clientCertPath != "" && *clientKeyPath != "" {
localCertSource, err := newLocalCertSource(*clientCertPath, *clientKeyPath, *remoteHost, *remotePort)
if err != nil {
return err
}
certSource = localCertSource
cert, err := x509.ParseCertificate(localCertSource.cert.Certificate[0])
if err != nil {
return err
}
instance = cert.Subject.String()
}

if certSource == nil {
return errors.New("no configuration found, need either a token and org / datbase / branch parameters or separate specified certificate source and remote host")
}

p, err := proxy.NewClient(proxy.Options{
CertSource: certSource,
LocalAddr: net.JoinHostPort(*host, *port),
RemoteAddr: *remoteAddr,
Instance: fmt.Sprintf("%s/%s/%s", *orgName, *dbName, *branchName),
RemoteAddr: net.JoinHostPort(*remoteHost, strconv.Itoa(*remotePort)),
Instance: instance,
})
if err != nil {
return fmt.Errorf("couldn't create proxy client: %s", err)
Expand Down Expand Up @@ -152,25 +164,14 @@ func (r *remoteCertSource) Cert(ctx context.Context, org, db, branch string) (*p
}, nil
}

func newLocalCertSource(caPath, certPath, keyPath, remoteAddr string, remotePort int) (*localCertSource, error) {
pem, err := ioutil.ReadFile(caPath)
if err != nil {
return nil, err
}

caCerts, err := parseCaCerts(pem)
if err != nil {
return nil, err
}

func newLocalCertSource(certPath, keyPath, remoteAddr string, remotePort int) (*localCertSource, error) {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
if err != nil {
return nil, err
}

return &localCertSource{
cert: cert,
caCerts: caCerts,
remoteAddr: remoteAddr,
remotePort: remotePort,
}, nil
Expand All @@ -179,7 +180,6 @@ func newLocalCertSource(caPath, certPath, keyPath, remoteAddr string, remotePort

type localCertSource struct {
cert tls.Certificate
caCerts []*x509.Certificate
remoteAddr string
remotePort int
}
Expand All @@ -194,25 +194,6 @@ func (c *localCertSource) Cert(ctx context.Context, org, db, branch string) (*pr
}, nil
}

func parseCaCerts(pemCert []byte) ([]*x509.Certificate, error) {
var certs []*x509.Certificate

for {
var certBlock *pem.Block
certBlock, pemCert = pem.Decode(pemCert)
if certBlock == nil {
break
}
cert, err := x509.ParseCertificate(certBlock.Bytes)
if err != nil {
return nil, err
}

certs = append(certs, cert)
}
return certs, nil
}

// printVersion formats a version string with the given information.
func printVersion(ver, commit, buildDate string) {
if ver == "" && buildDate == "" && commit == "" {
Expand Down

0 comments on commit 9109ec0

Please sign in to comment.