diff --git a/etcdmain/gateway.go b/etcdmain/gateway.go index 1a72bddcf08..8d4ce022c94 100644 --- a/etcdmain/gateway.go +++ b/etcdmain/gateway.go @@ -29,6 +29,7 @@ import ( var ( gatewayListenAddr string gatewayEndpoints []string + gatewayEndpointPolicy string gatewayDNSCluster string gatewayInsecureDiscovery bool getewayRetryDelay time.Duration @@ -67,6 +68,7 @@ func newGatewayStartCommand() *cobra.Command { cmd.Flags().StringVar(&gatewayListenAddr, "listen-addr", "127.0.0.1:23790", "listen address") cmd.Flags().StringVar(&gatewayDNSCluster, "discovery-srv", "", "DNS domain used to bootstrap initial cluster") + cmd.Flags().StringVar(&gatewayEndpointPolicy, "endpoint-policy", "round-robin", "Policy for selecting next connection's endpoint (round-robin, srv-priority)") cmd.Flags().BoolVar(&gatewayInsecureDiscovery, "insecure-discovery", false, "accept insecure SRV records") cmd.Flags().StringVar(&gatewayCA, "trusted-ca-file", "", "path to the client server TLS CA file.") @@ -91,17 +93,27 @@ func stripSchema(eps []string) []string { return endpoints } -func startGateway(cmd *cobra.Command, args []string) { - endpoints := gatewayEndpoints - if eps := discoverEndpoints(gatewayDNSCluster, gatewayCA, gatewayInsecureDiscovery); len(eps) != 0 { - endpoints = eps +func startGateway(cmd *cobra.Command, args []string) { + srvs := discoverEndpoints(gatewayDNSCluster, gatewayCA, gatewayInsecureDiscovery) + if len(srvs.Endpoints) == 0 { + srvs.Endpoints = gatewayEndpoints } - // Strip the schema from the endpoints because we start just a TCP proxy - endpoints = stripSchema(endpoints) + srvs.Endpoints = stripSchema(srvs.Endpoints) + if len(srvs.SRVs) == 0 { + for _, ep := range srvs.Endpoints { + h, p, err := net.SplitHostPort(ep) + if err != nil { + plog.Fatalf("error parsing endpoint %q", ep) + } + var port uint16 + fmt.Sscanf(p, "%d", &port) + srvs.SRVs = append(srvs.SRVs, &net.SRV{Target: h, Port: port}) + } + } - if len(endpoints) == 0 { + if len(srvs.Endpoints) == 0 { 
plog.Fatalf("no endpoints found") } @@ -111,10 +123,17 @@ func startGateway(cmd *cobra.Command, args []string) { os.Exit(1) } + picker, perr := tcpproxy.NewPicker(gatewayEndpointPolicy) + if perr != nil { + fmt.Fprintln(os.Stderr, perr) + os.Exit(1) + } + tp := tcpproxy.TCPProxy{ Listener: l, - Endpoints: endpoints, + Endpoints: srvs.SRVs, MonitorInterval: getewayRetryDelay, + Picker: picker, } // At this point, etcd gateway listener is initialized diff --git a/etcdmain/grpc_proxy.go b/etcdmain/grpc_proxy.go index 1f701ba1297..ae5af8bbf84 100644 --- a/etcdmain/grpc_proxy.go +++ b/etcdmain/grpc_proxy.go @@ -106,8 +106,9 @@ func startGRPCProxy(cmd *cobra.Command, args []string) { os.Exit(1) } - if eps := discoverEndpoints(grpcProxyDNSCluster, grpcProxyCA, grpcProxyInsecureDiscovery); len(eps) != 0 { - grpcProxyEndpoints = eps + srvs := discoverEndpoints(grpcProxyDNSCluster, grpcProxyCA, grpcProxyInsecureDiscovery) + if len(srvs.Endpoints) != 0 { + grpcProxyEndpoints = srvs.Endpoints } l, err := net.Listen("tcp", grpcProxyListenAddr) diff --git a/etcdmain/util.go b/etcdmain/util.go index 5de07275b5b..9657271d53a 100644 --- a/etcdmain/util.go +++ b/etcdmain/util.go @@ -22,19 +22,19 @@ import ( "github.com/coreos/etcd/pkg/transport" ) -func discoverEndpoints(dns string, ca string, insecure bool) (endpoints []string) { +func discoverEndpoints(dns string, ca string, insecure bool) (s srv.SRVClients) { if dns == "" { - return nil + return s } srvs, err := srv.GetClient("etcd-client", dns) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } - endpoints = srvs.Endpoints + endpoints := srvs.Endpoints plog.Infof("discovered the cluster %s from %s", endpoints, dns) if insecure { - return endpoints + return *srvs } // confirm TLS connections are good tlsInfo := transport.TLSInfo{ @@ -47,5 +47,19 @@ func discoverEndpoints(dns string, ca string, insecure bool) (endpoints []string plog.Warningf("%v", err) } plog.Infof("using discovered endpoints %v", endpoints) - return 
endpoints + + // map endpoints back to SRVClients struct with SRV data + eps := make(map[string]struct{}) + for _, ep := range endpoints { + eps[ep] = struct{}{} + } + for i := range srvs.Endpoints { + if _, ok := eps[srvs.Endpoints[i]]; !ok { + continue + } + s.Endpoints = append(s.Endpoints, srvs.Endpoints[i]) + s.SRVs = append(s.SRVs, srvs.SRVs[i]) + } + + return s } diff --git a/proxy/tcpproxy/userspace.go b/proxy/tcpproxy/userspace.go index 5de017a70de..a5e1e998ea2 100644 --- a/proxy/tcpproxy/userspace.go +++ b/proxy/tcpproxy/userspace.go @@ -15,7 +15,9 @@ package tcpproxy import ( + "fmt" "io" + "math/rand" "net" "sync" "time" @@ -29,6 +31,7 @@ var ( type remote struct { mu sync.Mutex + srv *net.SRV addr string inactive bool } @@ -59,26 +62,30 @@ func (r *remote) isActive() bool { type TCPProxy struct { Listener net.Listener - Endpoints []string + Endpoints []*net.SRV MonitorInterval time.Duration + Picker PickerFunc donec chan struct{} - mu sync.Mutex // guards the following fields - remotes []*remote - nextRemote int + mu sync.Mutex // guards the following fields + remotes []*remote } func (tp *TCPProxy) Run() error { + if tp.Picker == nil { + tp.Picker, _ = NewPicker("round-robin") + } tp.donec = make(chan struct{}) if tp.MonitorInterval == 0 { tp.MonitorInterval = 5 * time.Minute } - for _, ep := range tp.Endpoints { - tp.remotes = append(tp.remotes, &remote{addr: ep}) + for _, srv := range tp.Endpoints { + addr := fmt.Sprintf("%s:%d", srv.Target, srv.Port) + tp.remotes = append(tp.remotes, &remote{srv: srv, addr: addr}) } - plog.Printf("ready to proxy client requests to %v", tp.Endpoints) + plog.Printf("ready to proxy client requests to %+v", tp.Endpoints) go tp.runMonitor() for { in, err := tp.Listener.Accept() @@ -90,6 +97,61 @@ func (tp *TCPProxy) Run() error { } } +type PickerFunc func(tp *TCPProxy) *remote + +// NewPicker returns a picker to choose remotes for the tcp proxy. 
+func NewPicker(name string) (PickerFunc, error) { + switch name { + case "srv-priority": + return srvPicker, nil + case "round-robin": + nextRemote := 0 + f := func(tp *TCPProxy) *remote { + for i := 0; i < len(tp.remotes); i++ { + picked := tp.remotes[nextRemote] + nextRemote = (nextRemote + 1) % len(tp.remotes) + if picked.isActive() { + return picked + } + } + return nil + } + return f, nil + default: + } + return nil, fmt.Errorf("unknown picker %q", name) +} + +func srvPicker(tp *TCPProxy) *remote { + bestPr := uint16(65535) + w := 0 + var candidates []*remote + // find best priority class + for _, r := range tp.remotes { + switch { + case !r.isActive(): + case r.srv.Priority < bestPr: + bestPr = r.srv.Priority + candidates = []*remote{r} + w = int(r.srv.Weight) + 1 + case r.srv.Priority == bestPr: + candidates = append(candidates, r) + w += int(r.srv.Weight) + 1 + } + } + // randomly choose by weight + if len(candidates) > 0 { + choose := rand.Intn(w) + for i := 0; i < len(candidates); i++ { + choose -= int(candidates[i].srv.Weight) + 1 + if choose < 0 { + return candidates[i] + } + } + } + return nil +} + func (tp *TCPProxy) numRemotes() int { tp.mu.Lock() defer tp.mu.Unlock() @@ -102,10 +164,12 @@ func (tp *TCPProxy) serve(in net.Conn) { out net.Conn ) - for i := 0; i < tp.numRemotes(); i++ { - remote := tp.pick() - if !remote.isActive() { - continue + for { + tp.mu.Lock() + remote := tp.Picker(tp) + tp.mu.Unlock() + if remote == nil { + break } // TODO: add timeout out, err = net.Dial("tcp", remote.addr) @@ -132,16 +196,6 @@ func (tp *TCPProxy) serve(in net.Conn) { in.Close() } -// pick picks a remote in round-robin fashion -func (tp *TCPProxy) pick() *remote { - tp.mu.Lock() - defer tp.mu.Unlock() - - picked := tp.remotes[tp.nextRemote] - tp.nextRemote = (tp.nextRemote + 1) % len(tp.remotes) - return picked -} - func (tp *TCPProxy) runMonitor() { for { select { diff --git a/proxy/tcpproxy/userspace_test.go b/proxy/tcpproxy/userspace_test.go index 
e239c19c662..bf65f570c21 100644 --- a/proxy/tcpproxy/userspace_test.go +++ b/proxy/tcpproxy/userspace_test.go @@ -42,9 +42,11 @@ func TestUserspaceProxy(t *testing.T) { t.Fatal(err) } + var port uint16 + fmt.Sscanf(u.Port(), "%d", &port) p := TCPProxy{ Listener: l, - Endpoints: []string{u.Host}, + Endpoints: []*net.SRV{{Target: u.Hostname(), Port: port}}, } go p.Run() defer p.Stop()