diff --git a/cmd/root.go b/cmd/root.go index 5486949c53..d2812e7458 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -128,6 +128,8 @@ without having to manage any client SSL certificates.`, "Address on which to bind AlloyDB instance listeners.") cmd.PersistentFlags().IntVarP(&c.conf.Port, "port", "p", 5432, "Initial port to use for listeners. Subsequent listeners increment from this value.") + cmd.PersistentFlags().StringVarP(&c.conf.UnixSocket, "unix-socket", "u", "", + `Enables Unix sockets for all listeners using the provided directory.`) c.Command = cmd return c @@ -138,6 +140,15 @@ func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error { if len(args) == 0 { return newBadCommandError("missing instance uri (e.g., /projects/$PROJECTS/locations/$LOCTION/clusters/$CLUSTER/instances/$INSTANCES)") } + userHasSet := func(f string) bool { + return cmd.PersistentFlags().Lookup(f).Changed + } + if userHasSet("address") && userHasSet("unix-socket") { + return newBadCommandError("cannot specify --unix-socket and --address together") + } + if userHasSet("port") && userHasSet("unix-socket") { + return newBadCommandError("cannot specify --unix-socket and --port together") + } // First, validate global config. if ip := net.ParseIP(conf.Addr); ip == nil { return newBadCommandError(fmt.Sprintf("not a valid IP address: %q", conf.Addr)) @@ -171,7 +182,18 @@ func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error { return newBadCommandError(fmt.Sprintf("could not parse query: %q", res[1])) } - if a, ok := q["address"]; ok { + a, aok := q["address"] + p, pok := q["port"] + u, uok := q["unix-socket"] + + if aok && uok { + return newBadCommandError("cannot specify both address and unix-socket query params") + } + if pok && uok { + return newBadCommandError("cannot specify both port and unix-socket query params") + } + + if aok { if len(a) != 1 { return newBadCommandError(fmt.Sprintf("address query param should be only one value: %q", a)) } @@ -184,7 +206,7 @@ func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error { ic.Addr = a[0] } - if p, ok := q["port"]; ok { + if pok { if len(p) != 1 { return newBadCommandError(fmt.Sprintf("port query param should be only one value: %q", a)) } @@ -197,6 +219,14 @@ func parseConfig(cmd *cobra.Command, conf *proxy.Config, args []string) error { } ic.Port = pp } + + if uok { + if len(u) != 1 { + return newBadCommandError(fmt.Sprintf("unix query param should be only one value: %q", a)) + } + ic.UnixSocket = u[0] + + } } ics = append(ics, ic) } diff --git a/cmd/root_test.go b/cmd/root_test.go index 6aea877464..373606c4c8 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -137,6 +137,29 @@ func TestNewCommandArguments(t *testing.T) { CredentialsFile: "/path/to/file", }), }, + { + desc: "using the unix socket flag", + args: []string{"--unix-socket", "/path/to/dir/", "/projects/proj/locations/region/clusters/clust/instances/inst"}, + want: withDefaults(&proxy.Config{ + UnixSocket: "/path/to/dir/", + }), + }, + { + desc: "using the (short) unix socket flag", + args: []string{"-u", "/path/to/dir/", "/projects/proj/locations/region/clusters/clust/instances/inst"}, + want: withDefaults(&proxy.Config{ + UnixSocket: "/path/to/dir/", + }), + }, + { + desc: "using the unix socket query param", + args: []string{"/projects/proj/locations/region/clusters/clust/instances/inst?unix-socket=/path/to/dir/"}, + want: withDefaults(&proxy.Config{ + Instances: []proxy.InstanceConnConfig{{ + UnixSocket: "/path/to/dir/", + }}, + }), + }, } for _, tc := range tcs { @@ -210,6 +233,26 @@ func TestNewCommandWithErrors(t *testing.T) { "--token", "my-token", "--credentials-file", "/path/to/file", "/projects/proj/locations/region/clusters/clust/instances/inst"}, }, + { + desc: "when the unix socket query param contains multiple values", + args: []string{"/projects/proj/locations/region/clusters/clust/instances/inst?unix-socket=/one&unix-socket=/two"}, + }, + { + desc: "using the unix socket flag with addr", + args: []string{"-u", "/path/to/dir/", "-a", "127.0.0.1", "/projects/proj/locations/region/clusters/clust/instances/inst"}, + }, + { + desc: "using the unix socket flag with port", + args: []string{"-u", "/path/to/dir/", "-p", "5432", "/projects/proj/locations/region/clusters/clust/instances/inst"}, + }, + { + desc: "using the unix socket and addr query params", + args: []string{"/projects/proj/locations/region/clusters/clust/instances/inst?unix-socket=/path&address=127.0.0.1"}, + }, + { + desc: "using the unix socket and port query params", + args: []string{"/projects/proj/locations/region/clusters/clust/instances/inst?unix-socket=/path&port=5000"}, + }, } for _, tc := range tcs { diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go index 577116de8e..5e05872bbb 100644 --- a/internal/proxy/proxy.go +++ b/internal/proxy/proxy.go @@ -19,6 +19,10 @@ import ( "fmt" "io" "net" + "os" + "path/filepath" + "regexp" + "strings" "sync" "time" @@ -37,6 +41,10 @@ type InstanceConnConfig struct { Addr string // Port is the port on which to bind a listener for the instance. Port int + // UnixSocket is the directory where a Unix socket will be created, + // connected to the Cloud SQL instance. If set, takes precedence over Addr + // and Port. + UnixSocket string } // Config contains all the configuration provided by the caller. @@ -54,6 +62,10 @@ type Config struct { // increments from this value. Port int + // UnixSocket is the directory where Unix sockets will be created, + // connected to any Instances. If set, takes precedence over Addr and Port. + UnixSocket string + // Instances are configuration for individual instances. Instance // configuration takes precedence over global configuration. Instances []InstanceConnConfig @@ -95,6 +107,28 @@ func (c *portConfig) nextPort() int { return p } +var ( + // Instance URI is in the format: + // '/projects//locations//clusters//instances/' + // Additionally, we have to support legacy "domain-scoped" projects (e.g. "google.com:PROJECT") + instURIRegex = regexp.MustCompile("projects/([^:]+(:[^:]+)?)/locations/([^:]+)/clusters/([^:]+)/instances/([^:]+)") +) + +// UnixSocketDir returns a shorted instance connection name to prevent exceeding +// the Unix socket length. +func UnixSocketDir(dir, inst string) (string, error) { + m := instURIRegex.FindSubmatch([]byte(inst)) + if m == nil { + return "", fmt.Errorf("invalid instance name: %v", inst) + } + project := string(m[1]) + region := string(m[3]) + cluster := string(m[4]) + name := string(m[5]) + shortName := strings.Join([]string{project, region, cluster, name}, ".") + return filepath.Join(dir, shortName), nil +} + // Client represents the state of the current instantiation of the proxy. type Client struct { cmd *cobra.Command @@ -106,31 +140,79 @@ type Client struct { // NewClient completes the initial setup required to get the proxy to a "steady" state. func NewClient(ctx context.Context, d alloydb.Dialer, cmd *cobra.Command, conf *Config) (*Client, error) { - var mnts []*socketMount pc := newPortConfig(conf.Port) + var mnts []*socketMount for _, inst := range conf.Instances { - m := &socketMount{inst: inst.Name} - a := conf.Addr - if inst.Addr != "" { - a = inst.Addr - } - var np int - switch { - case inst.Port != 0: - np = inst.Port - default: // use next increment from conf.Port - np = pc.nextPort() + var ( + // network is one of "tcp" or "unix" + network string + // address is either a TCP host port, or a Unix socket + address string + ) + // IF + // a global Unix socket directory is NOT set AND + // an instance-level Unix socket is NOT set + // (e.g., I didn't set a Unix socket globally or for this instance) + // OR + // an instance-level TCP address or port IS set + // (e.g., I'm overriding any global settings to use TCP for this + // instance) + // use a TCP listener. + // Otherwise, use a Unix socket. + if (conf.UnixSocket == "" && inst.UnixSocket == "") || + (inst.Addr != "" || inst.Port != 0) { + network = "tcp" + + a := conf.Addr + if inst.Addr != "" { + a = inst.Addr + } + + var np int + switch { + case inst.Port != 0: + np = inst.Port + case conf.Port != 0: + np = pc.nextPort() + default: + np = pc.nextPort() + } + + address = net.JoinHostPort(a, fmt.Sprint(np)) + } else { + network = "unix" + + dir := conf.UnixSocket + if dir == "" { + dir = inst.UnixSocket + } + ud, err := UnixSocketDir(dir, inst.Name) + if err != nil { + return nil, err + } + // Create the parent directory that will hold the socket. + if _, err := os.Stat(ud); err != nil { + if err = os.Mkdir(ud, 0777); err != nil { + return nil, err + } + } + // use the Postgres-specific socket name + address = filepath.Join(ud, ".s.PGSQL.5432") } - addr, err := m.listen(ctx, "tcp", net.JoinHostPort(a, fmt.Sprint(np))) + + m := &socketMount{inst: inst.Name} + addr, err := m.listen(ctx, network, address) if err != nil { for _, m := range mnts { m.close() } return nil, fmt.Errorf("[%v] Unable to mount socket: %v", inst.Name, err) } + cmd.Printf("[%s] Listening on %s\n", inst.Name, addr.String()) mnts = append(mnts, m) } + return &Client{mnts: mnts, cmd: cmd, dialer: d}, nil } @@ -210,9 +292,9 @@ type socketMount struct { } // listen causes a socketMount to create a Listener at the specified network address. -func (s *socketMount) listen(ctx context.Context, network string, host string) (net.Addr, error) { +func (s *socketMount) listen(ctx context.Context, network string, address string) (net.Addr, error) { lc := net.ListenConfig{KeepAlive: 30 * time.Second} - l, err := lc.Listen(ctx, network, host) + l, err := lc.Listen(ctx, network, address) if err != nil { return nil, err } diff --git a/internal/proxy/proxy_test.go b/internal/proxy/proxy_test.go index 195740ab67..d850be7662 100644 --- a/internal/proxy/proxy_test.go +++ b/internal/proxy/proxy_test.go @@ -16,8 +16,10 @@ package proxy_test import ( "context" + "io/ioutil" "net" "os" + "path/filepath" "testing" "cloud.google.com/go/alloydbconn" @@ -28,9 +30,10 @@ import ( type fakeDialer struct{} type testCase struct { - desc string - in *proxy.Config - wantAddrs []string + desc string + in *proxy.Config + wantTCPAddrs []string + wantUnixAddrs []string } func (fakeDialer) Dial(ctx context.Context, inst string, opts ...alloydbconn.DialOption) (net.Conn, error) { @@ -41,10 +44,25 @@ func (fakeDialer) Close() error { return nil } +func createTempDir(t *testing.T) (string, func()) { + testDir, err := ioutil.TempDir("", "*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + return testDir, func() { + if err := os.RemoveAll(testDir); err != nil { + t.Logf("failed to cleanup temp dir: %v", err) + } + } +} + func TestClientInitialization(t *testing.T) { ctx := context.Background() - cluster1 := "/projects/proj/locations/region/clusters/clust/instances/inst1" - cluster2 := "/projects/proj/locations/region/clusters/clust/instances/inst2" + testDir, cleanup := createTempDir(t) + defer cleanup() + inst1 := "/projects/proj/locations/region/clusters/clust/instances/inst1" + inst2 := "/projects/proj/locations/region/clusters/clust/instances/inst2" + wantUnix := "proj.region.clust.inst1" tcs := []testCase{ { @@ -53,11 +71,11 @@ func TestClientInitialization(t *testing.T) { Addr: "127.0.0.1", Port: 5000, Instances: []proxy.InstanceConnConfig{ - {Name: cluster1}, - {Name: cluster2}, + {Name: inst1}, + {Name: inst2}, }, }, - wantAddrs: []string{"127.0.0.1:5000", "127.0.0.1:5001"}, + wantTCPAddrs: []string{"127.0.0.1:5000", "127.0.0.1:5001"}, }, { desc: "with instance address", @@ -65,10 +83,10 @@ func TestClientInitialization(t *testing.T) { Addr: "1.1.1.1", // bad address, binding shouldn't happen here. Port: 5000, Instances: []proxy.InstanceConnConfig{ - {Addr: "0.0.0.0", Name: cluster1}, + {Addr: "0.0.0.0", Name: inst1}, }, }, - wantAddrs: []string{"0.0.0.0:5000"}, + wantTCPAddrs: []string{"0.0.0.0:5000"}, }, { desc: "with instance port", @@ -76,10 +94,10 @@ func TestClientInitialization(t *testing.T) { Addr: "127.0.0.1", Port: 5000, Instances: []proxy.InstanceConnConfig{ - {Name: cluster1, Port: 6000}, + {Name: inst1, Port: 6000}, }, }, - wantAddrs: []string{"127.0.0.1:6000"}, + wantTCPAddrs: []string{"127.0.0.1:6000"}, }, { desc: "with global port and instance port", @@ -87,11 +105,11 @@ func TestClientInitialization(t *testing.T) { Addr: "127.0.0.1", Port: 5000, Instances: []proxy.InstanceConnConfig{ - {Name: cluster1}, - {Name: cluster2, Port: 6000}, + {Name: inst1}, + {Name: inst2, Port: 6000}, }, }, - wantAddrs: []string{ + wantTCPAddrs: []string{ "127.0.0.1:5000", "127.0.0.1:6000", }, @@ -102,15 +120,53 @@ func TestClientInitialization(t *testing.T) { Addr: "127.0.0.1", Port: 5432, // default port Instances: []proxy.InstanceConnConfig{ - {Name: cluster1}, - {Name: cluster2}, + {Name: inst1}, + {Name: inst2}, }, }, - wantAddrs: []string{ + wantTCPAddrs: []string{ "127.0.0.1:5432", "127.0.0.1:5433", }, }, + { + desc: "with a Unix socket", + in: &proxy.Config{ + UnixSocket: testDir, + Instances: []proxy.InstanceConnConfig{ + {Name: inst1}, + }, + }, + wantUnixAddrs: []string{ + filepath.Join(testDir, wantUnix, ".s.PGSQL.5432"), + }, + }, + { + desc: "with a global TCP host port and an instance Unix socket", + in: &proxy.Config{ + Addr: "127.0.0.1", + Port: 5000, + Instances: []proxy.InstanceConnConfig{ + {Name: inst1, UnixSocket: testDir}, + }, + }, + wantUnixAddrs: []string{ + filepath.Join(testDir, wantUnix, ".s.PGSQL.5432"), + }, + }, + { + desc: "with a global Unix socket and an instance TCP port", + in: &proxy.Config{ + Addr: "127.0.0.1", + UnixSocket: testDir, + Instances: []proxy.InstanceConnConfig{ + {Name: inst1, Port: 5000}, + }, + }, + wantTCPAddrs: []string{ + "127.0.0.1:5000", + }, + }, } _, isFlex := os.LookupEnv("FLEX") if !isFlex { @@ -120,10 +176,10 @@ func TestClientInitialization(t *testing.T) { Addr: "::1", Port: 5000, Instances: []proxy.InstanceConnConfig{ - {Name: cluster1}, + {Name: inst1}, }, }, - wantAddrs: []string{"[::1]:5000"}, + wantTCPAddrs: []string{"[::1]:5000"}, }) } @@ -134,7 +190,7 @@ func TestClientInitialization(t *testing.T) { t.Fatalf("want error = nil, got = %v", err) } defer c.Close() - for _, addr := range tc.wantAddrs { + for _, addr := range tc.wantTCPAddrs { conn, err := net.Dial("tcp", addr) if err != nil { t.Fatalf("want error = nil, got = %v", err) @@ -144,6 +200,44 @@ func TestClientInitialization(t *testing.T) { t.Logf("failed to close connection: %v", err) } } + + for _, addr := range tc.wantUnixAddrs { + conn, err := net.Dial("unix", addr) + if err != nil { + t.Fatalf("want error = nil, got = %v", err) + } + err = conn.Close() + if err != nil { + t.Logf("failed to close connection: %v", err) + } + } }) } } + +func TestClientInitializationWorksRepeatedly(t *testing.T) { + // The client creates a Unix socket on initial startup and does not remove + // it on shutdown. This test ensures the existing socket does not cause + // problems for a second invocation. + ctx := context.Background() + testDir, cleanup := createTempDir(t) + defer cleanup() + + in := &proxy.Config{ + UnixSocket: testDir, + Instances: []proxy.InstanceConnConfig{ + {Name: "/projects/proj/locations/region/clusters/clust/instances/inst1"}, + }, + } + c, err := proxy.NewClient(ctx, fakeDialer{}, &cobra.Command{}, in) + if err != nil { + t.Fatalf("want error = nil, got = %v", err) + } + c.Close() + + c, err = proxy.NewClient(ctx, fakeDialer{}, &cobra.Command{}, in) + if err != nil { + t.Fatalf("want error = nil, got = %v", err) + } + c.Close() +} diff --git a/tests/alloydb_test.go b/tests/alloydb_test.go index 6fc8787168..59ca2c44ea 100644 --- a/tests/alloydb_test.go +++ b/tests/alloydb_test.go @@ -17,10 +17,12 @@ package tests import ( "flag" "fmt" + "io/ioutil" "os" "testing" "cloud.google.com/go/alloydbconn/driver/pgxv4" + "github.com/GoogleCloudPlatform/alloydb-auth-proxy/internal/proxy" ) var ( @@ -60,6 +62,38 @@ func TestPostgresTCP(t *testing.T) { proxyConnTest(t, []string{*alloydbConnName}, "alloydb1", dsn) } +func createTempDir(t *testing.T) (string, func()) { + testDir, err := ioutil.TempDir("", "*") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + return testDir, func() { + if err := os.RemoveAll(testDir); err != nil { + t.Logf("failed to cleanup temp dir: %v", err) + } + } +} + +func TestPostgresUnix(t *testing.T) { + if testing.Short() { + t.Skip("skipping Postgres integration tests") + } + requirePostgresVars(t) + tmpDir, cleanup := createTempDir(t) + defer cleanup() + + dir, err := proxy.UnixSocketDir(tmpDir, *alloydbConnName) + if err != nil { + t.Fatalf("invalid connection name: %v", *alloydbConnName) + } + dsn := fmt.Sprintf("host=%s user=%s password=%s database=%s sslmode=disable", + dir, + *alloydbUser, *alloydbPass, *alloydbDB) + + proxyConnTest(t, + []string{"--unix-socket", tmpDir, *alloydbConnName}, "pgx", dsn) +} + func TestPostgresAuthWithToken(t *testing.T) { if testing.Short() { t.Skip("skipping Postgres integration tests")