diff --git a/embed/etcd.go b/embed/etcd.go index 2b5cf537384..ce64b839c64 100644 --- a/embed/etcd.go +++ b/embed/etcd.go @@ -82,7 +82,7 @@ type Etcd struct { type peerListener struct { net.Listener serve func() error - close func(context.Context) error + close func(time.Duration) error } // StartEtcd launches the etcd server and HTTP handlers for client/server communication. @@ -100,9 +100,9 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) { return } if !serving { - // errored before starting gRPC server for serveCtx.grpcServerC + // errored before starting gRPC server for serveCtx.serversC for _, sctx := range e.sctxs { - close(sctx.grpcServerC) + close(sctx.serversC) } } e.Close() @@ -219,23 +219,30 @@ func (e *Etcd) Config() Config { return e.cfg } +// Close gracefully shuts down all servers/listeners. func (e *Etcd) Close() { e.closeOnce.Do(func() { close(e.stopc) }) + timeout := 2 * time.Second + if e.Server != nil { + timeout = e.Server.Cfg.ReqTimeout() + } for _, sctx := range e.sctxs { - for gs := range sctx.grpcServerC { - e.stopGRPCServer(gs) + for ss := range sctx.serversC { + stopServers(ss, timeout) } } for _, sctx := range e.sctxs { sctx.cancel() } + for i := range e.Clients { if e.Clients[i] != nil { e.Clients[i].Close() } } + for i := range e.metricsListeners { e.metricsListeners[i].Close() } @@ -248,32 +255,46 @@ func (e *Etcd) Close() { // close all idle connections in peer handler (wait up to 1-second) for i := range e.Peers { if e.Peers[i] != nil && e.Peers[i].close != nil { - ctx, cancel := context.WithTimeout(context.Background(), time.Second) - e.Peers[i].close(ctx) - cancel() + e.Peers[i].close(time.Second) } } } -func (e *Etcd) stopGRPCServer(gs *grpc.Server) { - timeout := 2 * time.Second - if e.Server != nil { - timeout = e.Server.Cfg.ReqTimeout() +func stopServers(ss *servers, timeout time.Duration) { + shutdownNow := func() { + // first, close the http.Server + ctx, cancel := context.WithTimeout(context.Background(), timeout) + ss.http.Shutdown(ctx) + cancel() + + // and then close grpc.Server; cancels all active RPCs + ss.grpc.Stop() + } + + // do not grpc.Server.GracefulStop with TLS enabled etcd server + // See https://github.com/grpc/grpc-go/issues/1384#issuecomment-317124531 + // and https://github.com/coreos/etcd/issues/8916 + if ss.secure { + shutdownNow() + return } + ch := make(chan struct{}) go func() { defer close(ch) // close listeners to stop accepting new connections, // will block on any existing transports - gs.GracefulStop() + ss.grpc.GracefulStop() }() + // wait until all pending RPCs are finished select { case <-ch: case <-time.After(timeout): // took too long, manually close open transports // e.g. watch streams - gs.Stop() + shutdownNow() + // concurrent GracefulStop should be interrupted <-ch } @@ -297,7 +318,7 @@ func startPeerListeners(cfg *Config) (peers []*peerListener, err error) { for i := range peers { if peers[i] != nil && peers[i].close != nil { plog.Info("stopping listening for peers on ", cfg.LPUrls[i].String()) - peers[i].close(context.Background()) + peers[i].close(time.Second) } } }() @@ -311,13 +332,13 @@ func startPeerListeners(cfg *Config) (peers []*peerListener, err error) { plog.Warningf("The scheme of peer url %s is HTTP while client cert auth (--peer-client-cert-auth) is enabled. Ignored client cert auth for this url.", u.String()) } } - peers[i] = &peerListener{close: func(context.Context) error { return nil }} + peers[i] = &peerListener{close: func(time.Duration) error { return nil }} peers[i].Listener, err = rafthttp.NewListener(u, &cfg.PeerTLSInfo) if err != nil { return nil, err } // once serve, overwrite with 'http.Server.Shutdown' - peers[i].close = func(context.Context) error { + peers[i].close = func(time.Duration) error { return peers[i].Listener.Close() } plog.Info("listening for peers on ", u.String()) @@ -334,6 +355,7 @@ func (e *Etcd) servePeers() (err error) { return err } } + for _, p := range e.Peers { gs := v3rpc.Server(e.Server, peerTLScfg) m := cmux.New(p.Listener) @@ -345,12 +367,12 @@ func (e *Etcd) servePeers() (err error) { } go srv.Serve(m.Match(cmux.Any())) p.serve = func() error { return m.Serve() } - p.close = func(ctx context.Context) error { + p.close = func(timeout time.Duration) error { // gracefully shutdown http.Server // close open listeners, idle connections // until context cancel or time-out - e.stopGRPCServer(gs) - return srv.Shutdown(ctx) + stopServers(&servers{secure: peerTLScfg != nil, grpc: gs, http: srv}, timeout) + return nil } } diff --git a/embed/serve.go b/embed/serve.go index 12af13cb884..2811aaf0641 100644 --- a/embed/serve.go +++ b/embed/serve.go @@ -54,13 +54,19 @@ type serveCtx struct { userHandlers map[string]http.Handler serviceRegister func(*grpc.Server) - grpcServerC chan *grpc.Server + serversC chan *servers +} + +type servers struct { + secure bool + grpc *grpc.Server + http *http.Server } func newServeCtx() *serveCtx { ctx, cancel := context.WithCancel(context.Background()) return &serveCtx{ctx: ctx, cancel: cancel, userHandlers: make(map[string]http.Handler), - grpcServerC: make(chan *grpc.Server, 2), // in case sctx.insecure,sctx.secure true + serversC: make(chan *servers, 2), // in case sctx.insecure,sctx.secure true } } @@ -84,7 +90,6 @@ func (sctx *serveCtx) serve( if sctx.insecure { gs := v3rpc.Server(s, nil, gopts...) - sctx.grpcServerC <- gs v3electionpb.RegisterElectionServer(gs, servElection) v3lockpb.RegisterLockServer(gs, servLock) if sctx.serviceRegister != nil { @@ -93,9 +98,7 @@ func (sctx *serveCtx) serve( grpcl := m.Match(cmux.HTTP2()) go func() { errHandler(gs.Serve(grpcl)) }() - opts := []grpc.DialOption{ - grpc.WithInsecure(), - } + opts := []grpc.DialOption{grpc.WithInsecure()} gwmux, err := sctx.registerGateway(opts) if err != nil { return err @@ -109,6 +112,8 @@ func (sctx *serveCtx) serve( } httpl := m.Match(cmux.HTTP1()) go func() { errHandler(srvhttp.Serve(httpl)) }() + + sctx.serversC <- &servers{grpc: gs, http: srvhttp} plog.Noticef("serving insecure client requests on %s, this is strongly discouraged!", sctx.l.Addr().String()) } @@ -118,7 +123,6 @@ func (sctx *serveCtx) serve( return tlsErr } gs := v3rpc.Server(s, tlscfg, gopts...) - sctx.grpcServerC <- gs v3electionpb.RegisterElectionServer(gs, servElection) v3lockpb.RegisterLockServer(gs, servLock) if sctx.serviceRegister != nil { @@ -150,10 +154,11 @@ func (sctx *serveCtx) serve( } go func() { errHandler(srv.Serve(tlsl)) }() + sctx.serversC <- &servers{secure: true, grpc: gs, http: srv} plog.Infof("serving client requests on %s", sctx.l.Addr().String()) } - close(sctx.grpcServerC) + close(sctx.serversC) return m.Serve() } diff --git a/integration/embed_test.go b/integration/embed_test.go index 2883ec71622..b5ae39e32b1 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -47,7 +47,7 @@ func TestEmbedEtcd(t *testing.T) { {werr: "expected IP"}, } - urls := newEmbedURLs(10) + urls := newEmbedURLs(false, 10) // setup defaults for i := range tests { @@ -105,12 +105,19 @@ func TestEmbedEtcd(t *testing.T) { } } -// TestEmbedEtcdGracefulStop ensures embedded server stops +func TestEmbedEtcdGracefulStopSecure(t *testing.T) { testEmbedEtcdGracefulStop(t, true) } +func TestEmbedEtcdGracefulStopInsecure(t *testing.T) { testEmbedEtcdGracefulStop(t, false) } + +// testEmbedEtcdGracefulStop ensures embedded server stops // cutting existing transports. -func TestEmbedEtcdGracefulStop(t *testing.T) { +func testEmbedEtcdGracefulStop(t *testing.T, secure bool) { cfg := embed.NewConfig() + if secure { + cfg.ClientTLSInfo = testTLSInfo + cfg.PeerTLSInfo = testTLSInfo + } - urls := newEmbedURLs(2) + urls := newEmbedURLs(secure, 2) setupEmbedCfg(cfg, []url.URL{urls[0]}, []url.URL{urls[1]}) cfg.Dir = filepath.Join(os.TempDir(), fmt.Sprintf("embed-etcd")) @@ -123,7 +130,16 @@ func TestEmbedEtcdGracefulStop(t *testing.T) { } <-e.Server.ReadyNotify() // wait for e.Server to join the cluster - cli, err := clientv3.New(clientv3.Config{Endpoints: []string{urls[0].String()}}) + clientCfg := clientv3.Config{ + Endpoints: []string{urls[0].String()}, + } + if secure { + clientCfg.TLS, err = testTLSInfo.ClientConfig() + if err != nil { + t.Fatal(err) + } + } + cli, err := clientv3.New(clientCfg) if err != nil { t.Fatal(err) } @@ -146,9 +162,13 @@ func TestEmbedEtcdGracefulStop(t *testing.T) { } } -func newEmbedURLs(n int) (urls []url.URL) { +func newEmbedURLs(secure bool, n int) (urls []url.URL) { + scheme := "unix" + if secure { + scheme = "unixs" + } for i := 0; i < n; i++ { - u, _ := url.Parse(fmt.Sprintf("unix://localhost:%d%06d", os.Getpid(), i)) + u, _ := url.Parse(fmt.Sprintf("%s://localhost:%d%06d", scheme, os.Getpid(), i)) urls = append(urls, *u) } return urls