From 07ad18178d7ddb1fb71850937b54b0b4b5d5aa72 Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Thu, 4 May 2017 20:15:17 -0700 Subject: [PATCH 1/2] pkg/srv: package for SRV utilities Trying to decouple the v2 client from SRV code. Can't move into discovery/ since that creates a circular dependency. So, give up and move all the SRV code into a new package. --- client/discover.go | 19 ++++++ client/srv.go | 65 ------------------ client/srv_test.go | 102 ----------------------------- embed/config.go | 14 ++-- etcdmain/util.go | 5 +- {discovery => pkg/srv}/srv.go | 78 ++++++++++++++++------ {discovery => pkg/srv}/srv_test.go | 85 +++++++++++++++++++++++- 7 files changed, 170 insertions(+), 198 deletions(-) delete mode 100644 client/srv.go delete mode 100644 client/srv_test.go rename {discovery => pkg/srv}/srv.go (50%) rename {discovery => pkg/srv}/srv_test.go (59%) diff --git a/client/discover.go b/client/discover.go index bfd7aec93f5..442e35fe543 100644 --- a/client/discover.go +++ b/client/discover.go @@ -14,8 +14,27 @@ package client +import ( + "github.com/coreos/etcd/pkg/srv" +) + // Discoverer is an interface that wraps the Discover method. type Discoverer interface { // Discover looks up the etcd servers for the domain. Discover(domain string) ([]string, error) } + +type srvDiscover struct{} + +// NewSRVDiscover constructs a new Discoverer that uses the stdlib to lookup SRV records. +func NewSRVDiscover() Discoverer { + return &srvDiscover{} +} + +func (d *srvDiscover) Discover(domain string) ([]string, error) { + srvs, err := srv.GetClient("etcd-client", domain) + if err != nil { + return nil, err + } + return srvs.Endpoints, nil +} diff --git a/client/srv.go b/client/srv.go deleted file mode 100644 index fdfa3435921..00000000000 --- a/client/srv.go +++ /dev/null @@ -1,65 +0,0 @@ -// Copyright 2015 The etcd Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package client - -import ( - "fmt" - "net" - "net/url" -) - -var ( - // indirection for testing - lookupSRV = net.LookupSRV -) - -type srvDiscover struct{} - -// NewSRVDiscover constructs a new Discoverer that uses the stdlib to lookup SRV records. -func NewSRVDiscover() Discoverer { - return &srvDiscover{} -} - -// Discover looks up the etcd servers for the domain. -func (d *srvDiscover) Discover(domain string) ([]string, error) { - var urls []*url.URL - - updateURLs := func(service, scheme string) error { - _, addrs, err := lookupSRV(service, "tcp", domain) - if err != nil { - return err - } - for _, srv := range addrs { - urls = append(urls, &url.URL{ - Scheme: scheme, - Host: net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)), - }) - } - return nil - } - - errHTTPS := updateURLs("etcd-client-ssl", "https") - errHTTP := updateURLs("etcd-client", "http") - - if errHTTPS != nil && errHTTP != nil { - return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP) - } - - endpoints := make([]string, len(urls)) - for i := range urls { - endpoints[i] = urls[i].String() - } - return endpoints, nil -} diff --git a/client/srv_test.go b/client/srv_test.go deleted file mode 100644 index 64cf6032322..00000000000 --- a/client/srv_test.go +++ /dev/null @@ -1,102 +0,0 @@ -// Copyright 2015 The etcd Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package client - -import ( - "errors" - "net" - "reflect" - "testing" -) - -func TestSRVDiscover(t *testing.T) { - defer func() { lookupSRV = net.LookupSRV }() - - tests := []struct { - withSSL []*net.SRV - withoutSSL []*net.SRV - expected []string - }{ - { - []*net.SRV{}, - []*net.SRV{}, - []string{}, - }, - { - []*net.SRV{ - {Target: "10.0.0.1", Port: 2480}, - {Target: "10.0.0.2", Port: 2480}, - {Target: "10.0.0.3", Port: 2480}, - }, - []*net.SRV{}, - []string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480"}, - }, - { - []*net.SRV{ - {Target: "10.0.0.1", Port: 2480}, - {Target: "10.0.0.2", Port: 2480}, - {Target: "10.0.0.3", Port: 2480}, - }, - []*net.SRV{ - {Target: "10.0.0.1", Port: 7001}, - }, - []string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"}, - }, - { - []*net.SRV{ - {Target: "10.0.0.1", Port: 2480}, - {Target: "10.0.0.2", Port: 2480}, - {Target: "10.0.0.3", Port: 2480}, - }, - []*net.SRV{ - {Target: "10.0.0.1", Port: 7001}, - }, - []string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"}, - }, - { - []*net.SRV{ - {Target: "a.example.com", Port: 2480}, - {Target: "b.example.com", Port: 2480}, - {Target: "c.example.com", Port: 2480}, - }, - []*net.SRV{}, - []string{"https://a.example.com:2480", "https://b.example.com:2480", "https://c.example.com:2480"}, - }, - } - - for i, tt := range tests { - lookupSRV = func(service string, proto string, domain string) (string, []*net.SRV, error) { - if service == "etcd-client-ssl" { - return "", tt.withSSL, nil - } - if service == "etcd-client" { - return "", tt.withoutSSL, nil - } - return "", nil, errors.New("Unknown service in mock") - } - - d := NewSRVDiscover() - - endpoints, err := d.Discover("example.com") - if err != nil { - t.Fatalf("%d: err: %#v", i, err) - } - - if !reflect.DeepEqual(endpoints, tt.expected) { - t.Errorf("#%d: endpoints = %v, want %v", i, endpoints, tt.expected) - } - - } -} diff --git a/embed/config.go b/embed/config.go index 93431d1c672..e3926f66cb4 100644 --- a/embed/config.go +++ b/embed/config.go @@ -22,10 +22,10 @@ import ( "net/url" "strings" - "github.com/coreos/etcd/discovery" "github.com/coreos/etcd/etcdserver" "github.com/coreos/etcd/pkg/cors" "github.com/coreos/etcd/pkg/netutil" + "github.com/coreos/etcd/pkg/srv" "github.com/coreos/etcd/pkg/transport" "github.com/coreos/etcd/pkg/types" @@ -321,11 +321,15 @@ func (cfg *Config) PeerURLsMapAndToken(which string) (urlsmap types.URLsMap, tok urlsmap[cfg.Name] = cfg.APUrls token = cfg.Durl case cfg.DNSCluster != "": - var clusterStr string - clusterStr, err = discovery.SRVGetCluster(cfg.Name, cfg.DNSCluster, cfg.APUrls) - if err != nil { - return nil, "", err + clusterStrs, cerr := srv.GetCluster("etcd-server", cfg.Name, cfg.DNSCluster, cfg.APUrls) + if cerr != nil { + plog.Errorf("couldn't resolve during SRV discovery (%v)", cerr) + return nil, "", cerr + } + for _, s := range clusterStrs { + plog.Noticef("got bootstrap from DNS for etcd-server at %s", s) } + clusterStr := strings.Join(clusterStrs, ",") if strings.Contains(clusterStr, "https://") && cfg.PeerTLSInfo.CAFile == "" { cfg.PeerTLSInfo.ServerName = cfg.DNSCluster } diff --git a/etcdmain/util.go b/etcdmain/util.go index 23e19b44057..5de07275b5b 100644 --- a/etcdmain/util.go +++ b/etcdmain/util.go @@ -18,7 +18,7 @@ import ( "fmt" "os" - "github.com/coreos/etcd/client" + "github.com/coreos/etcd/pkg/srv" "github.com/coreos/etcd/pkg/transport" ) @@ -26,11 +26,12 @@ func discoverEndpoints(dns string, ca string, insecure bool) (endpoints []string if dns == "" { return nil } - endpoints, err := client.NewSRVDiscover().Discover(dns) + srvs, err := srv.GetClient("etcd-client", dns) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } + endpoints = srvs.Endpoints plog.Infof("discovered the cluster %s from %s", endpoints, dns) if insecure { return endpoints diff --git a/discovery/srv.go b/pkg/srv/srv.go similarity index 50% rename from discovery/srv.go rename to pkg/srv/srv.go index 782b6888f54..71a0af7956e 100644 --- a/discovery/srv.go +++ b/pkg/srv/srv.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package discovery +package srv import ( "fmt" @@ -25,14 +25,13 @@ import ( var ( // indirection for testing - lookupSRV = net.LookupSRV + lookupSRV = net.LookupSRV // net.DefaultResolver.LookupSRV when ctxs don't conflict resolveTCPAddr = net.ResolveTCPAddr ) -// SRVGetCluster gets the cluster information via DNS discovery. -// TODO(barakmich): Currently ignores priority and weight (as they don't make as much sense for a bootstrap) +// GetCluster gets the cluster information via DNS discovery. // Also sees each entry as a separate instance. -func SRVGetCluster(name, dns string, apurls types.URLs) (string, error) { +func GetCluster(service, name, dns string, apurls types.URLs) ([]string, error) { tempName := int(0) tcp2ap := make(map[string]url.URL) @@ -40,8 +39,7 @@ func SRVGetCluster(name, dns string, apurls types.URLs) (string, error) { for _, url := range apurls { tcpAddr, err := resolveTCPAddr("tcp", url.Host) if err != nil { - plog.Errorf("couldn't resolve host %s during SRV discovery", url.Host) - return "", err + return nil, err } tcp2ap[tcpAddr.String()] = url } @@ -55,9 +53,9 @@ func SRVGetCluster(name, dns string, apurls types.URLs) (string, error) { for _, srv := range addrs { port := fmt.Sprintf("%d", srv.Port) host := net.JoinHostPort(srv.Target, port) - tcpAddr, err := resolveTCPAddr("tcp", host) - if err != nil { - plog.Warningf("couldn't resolve host %s during SRV discovery", host) + tcpAddr, terr := resolveTCPAddr("tcp", host) + if terr != nil { + terr = err continue } n := "" @@ -73,31 +71,69 @@ func SRVGetCluster(name, dns string, apurls types.URLs) (string, error) { shortHost := strings.TrimSuffix(srv.Target, ".") urlHost := net.JoinHostPort(shortHost, port) stringParts = append(stringParts, fmt.Sprintf("%s=%s://%s", n, scheme, urlHost)) - plog.Noticef("got bootstrap from DNS for %s at %s://%s", service, scheme, urlHost) if ok && url.Scheme != scheme { - plog.Errorf("bootstrap at %s from DNS for %s has scheme mismatch with expected peer %s", scheme+"://"+urlHost, service, url.String()) + err = fmt.Errorf("bootstrap at %s from DNS for %s has scheme mismatch with expected peer %s", scheme+"://"+urlHost, service, url.String()) } } + if len(stringParts) == 0 { + return err + } return nil } failCount := 0 - err := updateNodeMap("etcd-server-ssl", "https") + err := updateNodeMap(service+"-ssl", "https") srvErr := make([]string, 2) if err != nil { - srvErr[0] = fmt.Sprintf("error querying DNS SRV records for _etcd-server-ssl %s", err) + srvErr[0] = fmt.Sprintf("error querying DNS SRV records for _%s-ssl %s", service, err) failCount++ } - err = updateNodeMap("etcd-server", "http") + err = updateNodeMap(service, "http") if err != nil { - srvErr[1] = fmt.Sprintf("error querying DNS SRV records for _etcd-server %s", err) + srvErr[1] = fmt.Sprintf("error querying DNS SRV records for _%s %s", service, err) failCount++ } if failCount == 2 { - plog.Warningf(srvErr[0]) - plog.Warningf(srvErr[1]) - plog.Errorf("SRV discovery failed: too many errors querying DNS SRV records") - return "", err + return nil, fmt.Errorf("srv: too many errors querying DNS SRV records (%q, %q)", srvErr[0], srvErr[1]) + } + return stringParts, nil +} + +type SRVClients struct { + Endpoints []string + SRVs []*net.SRV +} + +// GetClient looks up the client endpoints for a service and domain. +func GetClient(service, domain string) (*SRVClients, error) { + var urls []*url.URL + var srvs []*net.SRV + + updateURLs := func(service, scheme string) error { + _, addrs, err := lookupSRV(service, "tcp", domain) + if err != nil { + return err + } + for _, srv := range addrs { + urls = append(urls, &url.URL{ + Scheme: scheme, + Host: net.JoinHostPort(srv.Target, fmt.Sprintf("%d", srv.Port)), + }) + } + srvs = append(srvs, addrs...) + return nil + } + + errHTTPS := updateURLs(service+"-ssl", "https") + errHTTP := updateURLs(service, "http") + + if errHTTPS != nil && errHTTP != nil { + return nil, fmt.Errorf("dns lookup errors: %s and %s", errHTTPS, errHTTP) + } + + endpoints := make([]string, len(urls)) + for i := range urls { + endpoints[i] = urls[i].String() } - return strings.Join(stringParts, ","), nil + return &SRVClients{Endpoints: endpoints, SRVs: srvs}, nil } diff --git a/discovery/srv_test.go b/pkg/srv/srv_test.go similarity index 59% rename from discovery/srv_test.go rename to pkg/srv/srv_test.go index b9914a5544c..0386c9d2a09 100644 --- a/discovery/srv_test.go +++ b/pkg/srv/srv_test.go @@ -12,11 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -package discovery +package srv import ( "errors" "net" + "reflect" "strings" "testing" @@ -110,12 +111,90 @@ func TestSRVGetCluster(t *testing.T) { return "", nil, errors.New("Unknown service in mock") } urls := testutil.MustNewURLs(t, tt.urls) - str, err := SRVGetCluster(name, "example.com", urls) + str, err := GetCluster("etcd-server", name, "example.com", urls) if err != nil { t.Fatalf("%d: err: %#v", i, err) } - if str != tt.expected { + if strings.Join(str, ",") != tt.expected { t.Errorf("#%d: cluster = %s, want %s", i, str, tt.expected) } } } + +func TestSRVDiscover(t *testing.T) { + defer func() { lookupSRV = net.LookupSRV }() + + tests := []struct { + withSSL []*net.SRV + withoutSSL []*net.SRV + expected []string + }{ + { + []*net.SRV{}, + []*net.SRV{}, + []string{}, + }, + { + []*net.SRV{ + {Target: "10.0.0.1", Port: 2480}, + {Target: "10.0.0.2", Port: 2480}, + {Target: "10.0.0.3", Port: 2480}, + }, + []*net.SRV{}, + []string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480"}, + }, + { + []*net.SRV{ + {Target: "10.0.0.1", Port: 2480}, + {Target: "10.0.0.2", Port: 2480}, + {Target: "10.0.0.3", Port: 2480}, + }, + []*net.SRV{ + {Target: "10.0.0.1", Port: 7001}, + }, + []string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"}, + }, + { + []*net.SRV{ + {Target: "10.0.0.1", Port: 2480}, + {Target: "10.0.0.2", Port: 2480}, + {Target: "10.0.0.3", Port: 2480}, + }, + []*net.SRV{ + {Target: "10.0.0.1", Port: 7001}, + }, + []string{"https://10.0.0.1:2480", "https://10.0.0.2:2480", "https://10.0.0.3:2480", "http://10.0.0.1:7001"}, + }, + { + []*net.SRV{ + {Target: "a.example.com", Port: 2480}, + {Target: "b.example.com", Port: 2480}, + {Target: "c.example.com", Port: 2480}, + }, + []*net.SRV{}, + []string{"https://a.example.com:2480", "https://b.example.com:2480", "https://c.example.com:2480"}, + }, + } + + for i, tt := range tests { + lookupSRV = func(service string, proto string, domain string) (string, []*net.SRV, error) { + if service == "etcd-client-ssl" { + return "", tt.withSSL, nil + } + if service == "etcd-client" { + return "", tt.withoutSSL, nil + } + return "", nil, errors.New("Unknown service in mock") + } + + srvs, err := GetClient("etcd-client", "example.com") + if err != nil { + t.Fatalf("%d: err: %#v", i, err) + } + + if !reflect.DeepEqual(srvs.Endpoints, tt.expected) { + t.Errorf("#%d: endpoints = %v, want %v", i, srvs.Endpoints, tt.expected) + } + + } +} From c2328140035d80e7343af1bab65c6b7c6103cd42 Mon Sep 17 00:00:00 2001 From: Anthony Romano Date: Thu, 4 May 2017 21:57:54 -0700 Subject: [PATCH 2/2] etcdmain, tcpproxy: srv-priority policy Adds DNS SRV weighting and priorities to gateway. Partially addresses #4378 --- etcdmain/gateway.go | 27 ++++++--- etcdmain/grpc_proxy.go | 5 +- etcdmain/util.go | 24 ++++++-- proxy/tcpproxy/userspace.go | 95 ++++++++++++++++++++++++-------- proxy/tcpproxy/userspace_test.go | 4 +- 5 files changed, 115 insertions(+), 40 deletions(-) diff --git a/etcdmain/gateway.go b/etcdmain/gateway.go index 1a72bddcf08..5487414ebd5 100644 --- a/etcdmain/gateway.go +++ b/etcdmain/gateway.go @@ -91,17 +91,28 @@ func stripSchema(eps []string) []string { return endpoints } -func startGateway(cmd *cobra.Command, args []string) { - endpoints := gatewayEndpoints - if eps := discoverEndpoints(gatewayDNSCluster, gatewayCA, gatewayInsecureDiscovery); len(eps) != 0 { - endpoints = eps +func startGateway(cmd *cobra.Command, args []string) { + srvs := discoverEndpoints(gatewayDNSCluster, gatewayCA, gatewayInsecureDiscovery) + if len(srvs.Endpoints) == 0 { + // no endpoints discovered, fall back to provided endpoints + srvs.Endpoints = gatewayEndpoints } - // Strip the schema from the endpoints because we start just a TCP proxy - endpoints = stripSchema(endpoints) + srvs.Endpoints = stripSchema(srvs.Endpoints) + if len(srvs.SRVs) == 0 { + for _, ep := range srvs.Endpoints { + h, p, err := net.SplitHostPort(ep) + if err != nil { + plog.Fatalf("error parsing endpoint %q", ep) + } + var port uint16 + fmt.Sscanf(p, "%d", &port) + srvs.SRVs = append(srvs.SRVs, &net.SRV{Target: h, Port: port}) + } + } - if len(endpoints) == 0 { + if len(srvs.Endpoints) == 0 { plog.Fatalf("no endpoints found") } @@ -113,7 +124,7 @@ func startGateway(cmd *cobra.Command, args []string) { tp := tcpproxy.TCPProxy{ Listener: l, - Endpoints: endpoints, + Endpoints: srvs.SRVs, MonitorInterval: getewayRetryDelay, } diff --git a/etcdmain/grpc_proxy.go b/etcdmain/grpc_proxy.go index 1f701ba1297..ae5af8bbf84 100644 --- a/etcdmain/grpc_proxy.go +++ b/etcdmain/grpc_proxy.go @@ -106,8 +106,9 @@ func startGRPCProxy(cmd *cobra.Command, args []string) { os.Exit(1) } - if eps := discoverEndpoints(grpcProxyDNSCluster, grpcProxyCA, grpcProxyInsecureDiscovery); len(eps) != 0 { - grpcProxyEndpoints = eps + srvs := discoverEndpoints(grpcProxyDNSCluster, grpcProxyCA, grpcProxyInsecureDiscovery) + if len(srvs.Endpoints) != 0 { + grpcProxyEndpoints = srvs.Endpoints } l, err := net.Listen("tcp", grpcProxyListenAddr) diff --git a/etcdmain/util.go b/etcdmain/util.go index 5de07275b5b..9657271d53a 100644 --- a/etcdmain/util.go +++ b/etcdmain/util.go @@ -22,19 +22,19 @@ import ( "github.com/coreos/etcd/pkg/transport" ) -func discoverEndpoints(dns string, ca string, insecure bool) (endpoints []string) { +func discoverEndpoints(dns string, ca string, insecure bool) (s srv.SRVClients) { if dns == "" { - return nil + return s } srvs, err := srv.GetClient("etcd-client", dns) if err != nil { fmt.Fprintln(os.Stderr, err) os.Exit(1) } - endpoints = srvs.Endpoints + endpoints := srvs.Endpoints plog.Infof("discovered the cluster %s from %s", endpoints, dns) if insecure { - return endpoints + return *srvs } // confirm TLS connections are good tlsInfo := transport.TLSInfo{ @@ -47,5 +47,19 @@ func discoverEndpoints(dns string, ca string, insecure bool) (endpoints []string plog.Warningf("%v", err) } plog.Infof("using discovered endpoints %v", endpoints) - return endpoints + + // map endpoints back to SRVClients struct with SRV data + eps := make(map[string]struct{}) + for _, ep := range endpoints { + eps[ep] = struct{}{} + } + for i := range srvs.Endpoints { + if _, ok := eps[srvs.Endpoints[i]]; !ok { + continue + } + s.Endpoints = append(s.Endpoints, srvs.Endpoints[i]) + s.SRVs = append(s.SRVs, srvs.SRVs[i]) + } + + return s } diff --git a/proxy/tcpproxy/userspace.go b/proxy/tcpproxy/userspace.go index 5de017a70de..01e40a24c5e 100644 --- a/proxy/tcpproxy/userspace.go +++ b/proxy/tcpproxy/userspace.go @@ -15,7 +15,9 @@ package tcpproxy import ( + "fmt" "io" + "math/rand" "net" "sync" "time" @@ -29,6 +31,7 @@ var ( type remote struct { mu sync.Mutex + srv *net.SRV addr string inactive bool } @@ -59,14 +62,14 @@ func (r *remote) isActive() bool { type TCPProxy struct { Listener net.Listener - Endpoints []string + Endpoints []*net.SRV MonitorInterval time.Duration donec chan struct{} - mu sync.Mutex // guards the following fields - remotes []*remote - nextRemote int + mu sync.Mutex // guards the following fields + remotes []*remote + pickCount int // for round robin } func (tp *TCPProxy) Run() error { @@ -74,11 +77,12 @@ func (tp *TCPProxy) Run() error { if tp.MonitorInterval == 0 { tp.MonitorInterval = 5 * time.Minute } - for _, ep := range tp.Endpoints { - tp.remotes = append(tp.remotes, &remote{addr: ep}) + for _, srv := range tp.Endpoints { + addr := fmt.Sprintf("%s:%d", srv.Target, srv.Port) + tp.remotes = append(tp.remotes, &remote{srv: srv, addr: addr}) } - plog.Printf("ready to proxy client requests to %v", tp.Endpoints) + plog.Printf("ready to proxy client requests to %+v", tp.Endpoints) go tp.runMonitor() for { in, err := tp.Listener.Accept() @@ -90,10 +94,61 @@ func (tp *TCPProxy) Run() error { } } -func (tp *TCPProxy) numRemotes() int { - tp.mu.Lock() - defer tp.mu.Unlock() - return len(tp.remotes) +func (tp *TCPProxy) pick() *remote { + var weighted []*remote + var unweighted []*remote + + bestPr := uint16(65535) + w := 0 + // find best priority class + for _, r := range tp.remotes { + switch { + case !r.isActive(): + case r.srv.Priority < bestPr: + bestPr = r.srv.Priority + w = 0 + weighted, unweighted = nil, nil + unweighted = []*remote{r} + fallthrough + case r.srv.Priority == bestPr: + if r.srv.Weight > 0 { + weighted = append(weighted, r) + w += int(r.srv.Weight) + } else { + unweighted = append(unweighted, r) + } + } + } + if weighted != nil { + if len(unweighted) > 0 && rand.Intn(100) == 1 { + // In the presence of records containing weights greater + // than 0, records with weight 0 should have a very small + // chance of being selected. + r := unweighted[tp.pickCount%len(unweighted)] + tp.pickCount++ + return r + } + // choose a uniform random number between 0 and the sum computed + // (inclusive), and select the RR whose running sum value is the + // first in the selected order + choose := rand.Intn(w) + for i := 0; i < len(weighted); i++ { + choose -= int(weighted[i].srv.Weight) + if choose <= 0 { + return weighted[i] + } + } + } + if unweighted != nil { + for i := 0; i < len(tp.remotes); i++ { + picked := tp.remotes[tp.pickCount%len(tp.remotes)] + tp.pickCount++ + if picked.isActive() { + return picked + } + } + } + return nil } func (tp *TCPProxy) serve(in net.Conn) { @@ -102,10 +157,12 @@ func (tp *TCPProxy) serve(in net.Conn) { out net.Conn ) - for i := 0; i < tp.numRemotes(); i++ { + for { + tp.mu.Lock() remote := tp.pick() - if !remote.isActive() { - continue + tp.mu.Unlock() + if remote == nil { + break } // TODO: add timeout out, err = net.Dial("tcp", remote.addr) @@ -132,16 +189,6 @@ func (tp *TCPProxy) serve(in net.Conn) { in.Close() } -// pick picks a remote in round-robin fashion -func (tp *TCPProxy) pick() *remote { - tp.mu.Lock() - defer tp.mu.Unlock() - - picked := tp.remotes[tp.nextRemote] - tp.nextRemote = (tp.nextRemote + 1) % len(tp.remotes) - return picked -} - func (tp *TCPProxy) runMonitor() { for { select { diff --git a/proxy/tcpproxy/userspace_test.go b/proxy/tcpproxy/userspace_test.go index e239c19c662..bf65f570c21 100644 --- a/proxy/tcpproxy/userspace_test.go +++ b/proxy/tcpproxy/userspace_test.go @@ -42,9 +42,11 @@ func TestUserspaceProxy(t *testing.T) { t.Fatal(err) } + var port uint16 + fmt.Sscanf(u.Port(), "%d", &port) p := TCPProxy{ Listener: l, - Endpoints: []string{u.Host}, + Endpoints: []*net.SRV{{Target: u.Hostname(), Port: port}}, } go p.Run() defer p.Stop()