diff --git a/embed/etcd.go b/embed/etcd.go
index 5857044d26ab..713bb61e8e9d 100644
--- a/embed/etcd.go
+++ b/embed/etcd.go
@@ -29,13 +29,16 @@ import (
 	"github.com/coreos/etcd/etcdserver"
 	"github.com/coreos/etcd/etcdserver/api/etcdhttp"
 	"github.com/coreos/etcd/etcdserver/api/v2http"
+	"github.com/coreos/etcd/etcdserver/api/v3rpc"
 	"github.com/coreos/etcd/pkg/cors"
 	"github.com/coreos/etcd/pkg/debugutil"
 	runtimeutil "github.com/coreos/etcd/pkg/runtime"
 	"github.com/coreos/etcd/pkg/transport"
 	"github.com/coreos/etcd/pkg/types"
 	"github.com/coreos/etcd/rafthttp"
 
+	"github.com/coreos/pkg/capnslog"
+	"github.com/soheilhy/cmux"
 	"google.golang.org/grpc"
 	"google.golang.org/grpc/keepalive"
 )
@@ -60,12 +63,14 @@ const (
 type Etcd struct {
 	Peers   []*peerListener
 	Clients []net.Listener
-	Server  *etcdserver.EtcdServer
+	// a map of contexts for the servers that serve client requests.
+	sctxs map[string]*serveCtx
+
+	Server *etcdserver.EtcdServer
 
 	cfg   Config
 	stopc chan struct{}
 	errc  chan error
-	sctxs map[string]*serveCtx
 
 	closeOnce sync.Once
 }
@@ -91,9 +96,9 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
 			return
 		}
 		if !serving {
-			// errored before starting gRPC server for serveCtx.grpcServerC
+			// errored before starting gRPC server for serveCtx.serversC
			for _, sctx := range e.sctxs {
-				close(sctx.grpcServerC)
+				close(sctx.serversC)
 			}
 		}
 		e.Close()
@@ -101,10 +106,10 @@
 	}()
 
 	if e.Peers, err = startPeerListeners(cfg); err != nil {
-		return
+		return e, err
 	}
 	if e.sctxs, err = startClientListeners(cfg); err != nil {
-		return
+		return e, err
 	}
 	for _, sctx := range e.sctxs {
 		e.Clients = append(e.Clients, sctx.l)
@@ -150,37 +155,23 @@
 	}
 
 	if e.Server, err = etcdserver.NewServer(srvcfg); err != nil {
-		return
-	}
-
-	// configure peer handlers after rafthttp.Transport started
-	ph := etcdhttp.NewPeerHandler(e.Server)
-	for _, p := range e.Peers {
-		srv := &http.Server{
-			Handler:     ph,
-			ReadTimeout: 5 * time.Minute,
-			ErrorLog:    defaultLog.New(ioutil.Discard, "", 0), // do not log user error
-		}
-
-		l := p.Listener
-		p.serve = func() error { return srv.Serve(l) }
-		p.close = func(ctx context.Context) error {
-			// gracefully shutdown http.Server
-			// close open listeners, idle connections
-			// until context cancel or time-out
-			return srv.Shutdown(ctx)
-		}
+		return e, err
 	}
 
 	// buffer channel so goroutines on closed connections won't wait forever
 	e.errc = make(chan error, len(e.Peers)+len(e.Clients)+2*len(e.sctxs))
 
 	e.Server.Start()
-	if err = e.serve(); err != nil {
-		return
+
+	if err = e.servePeers(); err != nil {
+		return e, err
 	}
+	if err = e.serveClients(); err != nil {
+		return e, err
+	}
+
 	serving = true
-	return
+	return e, nil
 }
 
 // Config returns the current configuration.
@@ -188,38 +179,29 @@ func (e *Etcd) Config() Config {
 	return e.cfg
 }
 
+// Close gracefully shuts down all servers/listeners.
+// Client requests will be terminated with request timeout.
+// After timeout, any remaining requests are forcefully closed.
 func (e *Etcd) Close() {
 	e.closeOnce.Do(func() { close(e.stopc) })
 
+	// close client requests with request timeout
 	timeout := 2 * time.Second
 	if e.Server != nil {
 		timeout = e.Server.Cfg.ReqTimeout()
 	}
 	for _, sctx := range e.sctxs {
-		for gs := range sctx.grpcServerC {
-			ch := make(chan struct{})
-			go func() {
-				defer close(ch)
-				// close listeners to stop accepting new connections,
-				// will block on any existing transports
-				gs.GracefulStop()
-			}()
-			// wait until all pending RPCs are finished
-			select {
-			case <-ch:
-			case <-time.After(timeout):
-				// took too long, manually close open transports
-				// e.g. watch streams
-				gs.Stop()
-				// concurrent GracefulStop should be interrupted
-				<-ch
-			}
+		for ss := range sctx.serversC {
+			ctx, cancel := context.WithTimeout(context.Background(), timeout)
+			stopServers(ctx, ss)
+			cancel()
 		}
 	}
 
 	for _, sctx := range e.sctxs {
 		sctx.cancel()
 	}
+
 	for i := range e.Clients {
 		if e.Clients[i] != nil {
 			e.Clients[i].Close()
@@ -241,6 +223,43 @@ func (e *Etcd) Close() {
 	}
 }
 
+func stopServers(ctx context.Context, ss *servers) {
+	shutdownNow := func() {
+		// first, close the http.Server
+		ss.http.Shutdown(ctx)
+		// then close grpc.Server; cancels all active RPCs
+		ss.grpc.Stop()
+	}
+
+	// do not grpc.Server.GracefulStop with TLS enabled etcd server
+	// See https://github.com/grpc/grpc-go/issues/1384#issuecomment-317124531
+	// and https://github.com/coreos/etcd/issues/8916
+	if ss.secure {
+		shutdownNow()
+		return
+	}
+
+	ch := make(chan struct{})
+	go func() {
+		defer close(ch)
+		// close listeners to stop accepting new connections,
+		// will block on any existing transports
+		ss.grpc.GracefulStop()
+	}()
+
+	// wait until all pending RPCs are finished
+	select {
+	case <-ch:
+	case <-ctx.Done():
+		// took too long, manually close open transports
+		// e.g. watch streams
+		shutdownNow()
+
+		// concurrent GracefulStop should be interrupted
+		<-ch
+	}
+}
+
 func (e *Etcd) Err() <-chan error { return e.errc }
 
 func startPeerListeners(cfg *Config) (peers []*peerListener, err error) {
@@ -269,7 +288,9 @@ func startPeerListeners(cfg *Config) (peers []*peerListener, err error) {
 		for i := range peers {
 			if peers[i] != nil && peers[i].close != nil {
 				plog.Info("stopping listening for peers on ", cfg.LPUrls[i].String())
-				peers[i].close(context.Background())
+				ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+				peers[i].close(ctx)
+				cancel()
 			}
 		}
 	}()
@@ -297,6 +318,45 @@
 	return peers, nil
 }
 
+// configure peer handlers after rafthttp.Transport started
+func (e *Etcd) servePeers() (err error) {
+	ph := etcdhttp.NewPeerHandler(e.Server)
+	var peerTLScfg *tls.Config
+	if !e.cfg.PeerTLSInfo.Empty() {
+		if peerTLScfg, err = e.cfg.PeerTLSInfo.ServerConfig(); err != nil {
+			return err
+		}
+	}
+
+	for _, p := range e.Peers {
+		gs := v3rpc.Server(e.Server, peerTLScfg)
+		m := cmux.New(p.Listener)
+		go gs.Serve(m.Match(cmux.HTTP2()))
+		srv := &http.Server{
+			Handler:     grpcHandlerFunc(gs, ph),
+			ReadTimeout: 5 * time.Minute,
+			ErrorLog:    defaultLog.New(ioutil.Discard, "", 0), // do not log user error
+		}
+		go srv.Serve(m.Match(cmux.Any()))
+		p.serve = func() error { return m.Serve() }
+		p.close = func(ctx context.Context) error {
+			// gracefully shutdown http.Server
+			// close open listeners, idle connections
+			// until context cancel or time-out
+			stopServers(ctx, &servers{secure: peerTLScfg != nil, grpc: gs, http: srv})
+			return nil
+		}
+	}
+
+	// start peer servers in a goroutine
+	for _, pl := range e.Peers {
+		go func(l *peerListener) {
+			e.errHandler(l.serve())
+		}(pl)
+	}
+	return nil
+}
+
 func startClientListeners(cfg *Config) (sctxs map[string]*serveCtx, err error) {
 	if cfg.ClientAutoTLS && cfg.ClientTLSInfo.Empty() {
 		chosts := make([]string, len(cfg.LCUrls))
@@ -388,7 +448,7 @@ func startClientListeners(cfg *Config) (sctxs map[string]*serveCtx, err error) {
 	return sctxs, nil
 }
 
-func (e *Etcd) serve() (err error) {
+func (e *Etcd) serveClients() (err error) {
 	var ctlscfg *tls.Config
 	if !e.cfg.ClientTLSInfo.Empty() {
 		plog.Infof("ClientTLS: %s", e.cfg.ClientTLSInfo)
@@ -401,13 +461,6 @@
 		plog.Infof("cors = %s", e.cfg.CorsInfo)
 	}
 
-	// Start the peer server in a goroutine
-	for _, pl := range e.Peers {
-		go func(l *peerListener) {
-			e.errHandler(l.serve())
-		}(pl)
-	}
-
 	// Start a client server goroutine for each listen address
 	var h http.Handler
 	if e.Config().EnableV2 {
@@ -433,6 +486,8 @@
 			Timeout: e.cfg.GRPCKeepAliveTimeout,
 		}))
 	}
+
+	// start client servers in a goroutine
 	for _, sctx := range e.sctxs {
 		go func(s *serveCtx) {
 			e.errHandler(s.serve(e.Server, ctlscfg, h, e.errHandler, gopts...))
diff --git a/embed/serve.go b/embed/serve.go
index 3627f88a9587..b659bf8b7d68 100644
--- a/embed/serve.go
+++ b/embed/serve.go
@@ -53,13 +53,22 @@ type serveCtx struct {
 
 	userHandlers    map[string]http.Handler
 	serviceRegister func(*grpc.Server)
-	grpcServerC     chan *grpc.Server
+	serversC chan *servers
+}
+
+type servers struct {
+	secure bool
+	grpc   *grpc.Server
+	http   *http.Server
 }
 
 func newServeCtx() *serveCtx {
 	ctx, cancel := context.WithCancel(context.Background())
-	return &serveCtx{ctx: ctx, cancel: cancel, userHandlers: make(map[string]http.Handler),
-		grpcServerC: make(chan *grpc.Server, 2), // in case sctx.insecure,sctx.secure true
+	return &serveCtx{
+		ctx:          ctx,
+		cancel:       cancel,
+		userHandlers: make(map[string]http.Handler),
+		serversC:     make(chan *servers, 2), // in case sctx.insecure,sctx.secure true
 	}
 }
 
@@ -83,7 +92,6 @@ func (sctx *serveCtx) serve(
 
 	if sctx.insecure {
 		gs := v3rpc.Server(s, nil, gopts...)
-		sctx.grpcServerC <- gs
 		v3electionpb.RegisterElectionServer(gs, servElection)
 		v3lockpb.RegisterLockServer(gs, servLock)
 		if sctx.serviceRegister != nil {
@@ -92,9 +100,7 @@
 		grpcl := m.Match(cmux.HTTP2())
 		go func() { errHandler(gs.Serve(grpcl)) }()
 
-		opts := []grpc.DialOption{
-			grpc.WithInsecure(),
-		}
+		opts := []grpc.DialOption{grpc.WithInsecure()}
 		gwmux, err := sctx.registerGateway(opts)
 		if err != nil {
 			return err
@@ -108,12 +114,13 @@
 		}
 		httpl := m.Match(cmux.HTTP1())
 		go func() { errHandler(srvhttp.Serve(httpl)) }()
+
+		sctx.serversC <- &servers{grpc: gs, http: srvhttp}
 		plog.Noticef("serving insecure client requests on %s, this is strongly discouraged!", sctx.l.Addr().String())
 	}
 
 	if sctx.secure {
 		gs := v3rpc.Server(s, tlscfg, gopts...)
-		sctx.grpcServerC <- gs
 		v3electionpb.RegisterElectionServer(gs, servElection)
 		v3lockpb.RegisterLockServer(gs, servLock)
 		if sctx.serviceRegister != nil {
@@ -142,10 +149,11 @@
 		}
 		go func() { errHandler(srv.Serve(tlsl)) }()
 
+		sctx.serversC <- &servers{secure: true, grpc: gs, http: srv}
 		plog.Infof("serving client requests on %s", sctx.l.Addr().String())
 	}
 
-	close(sctx.grpcServerC)
+	close(sctx.serversC)
 	return m.Serve()
 }
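
For readers outside the etcd tree, the shutdown sequence that stopServers implements above can be sketched standalone. The following is a minimal Go sketch, not etcd code: the helper name stopWithTimeout is hypothetical, and a bare grpc.Server/http.Server pair stands in for etcd's servers struct. GracefulStop is raced against a context deadline; on timeout the servers are force-closed, and for TLS-backed servers GracefulStop is skipped entirely (see grpc/grpc-go#1384 and coreos/etcd#8916, cited in the patch).

package main

import (
	"context"
	"log"
	"net"
	"net/http"
	"time"

	"google.golang.org/grpc"
)

// stopWithTimeout (hypothetical helper) mirrors stopServers: try
// grpc.Server.GracefulStop, but fall back to a hard stop once ctx expires.
// With secure == true, GracefulStop is skipped entirely, since it can hang
// on TLS connections (grpc-go#1384, etcd#8916).
func stopWithTimeout(ctx context.Context, secure bool, gs *grpc.Server, hs *http.Server) {
	shutdownNow := func() {
		hs.Shutdown(ctx) // close the HTTP server's listeners and idle connections
		gs.Stop()        // hard-cancel all active RPCs, e.g. watch streams
	}
	if secure {
		shutdownNow()
		return
	}
	done := make(chan struct{})
	go func() {
		defer close(done)
		gs.GracefulStop() // blocks until pending RPCs drain
	}()
	select {
	case <-done: // drained before the deadline
	case <-ctx.Done():
		shutdownNow() // deadline hit: force-close open transports
		<-done        // the concurrent GracefulStop now returns
	}
}

func main() {
	l, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		log.Fatal(err)
	}
	gs := grpc.NewServer()
	go gs.Serve(l) // serve gRPC until stopped

	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()
	stopWithTimeout(ctx, false, gs, &http.Server{})
}

The ordering inside shutdownNow follows the patch: http.Server.Shutdown first stops new connections from being accepted, then grpc.Server.Stop cancels in-flight RPCs such as watch streams.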