Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We鈥檒l occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid panic on shut down when TLS configuration is present #8986

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 42 additions & 20 deletions embed/etcd.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ type Etcd struct {
type peerListener struct {
net.Listener
serve func() error
close func(context.Context) error
close func(time.Duration) error
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure why do we need this change. can you explain?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Previous (e *Etcd) stopGRPCServer(gs *grpc.Server) method was using embed.Etcd only to compute request timeout. Now, we want to use this same function with *http.Server parameter. Especially, in line https://github.com/coreos/etcd/pull/8986/files#diff-d7a47eb75475dba263540eb7f8456e50R258.

If we have time.Duration as an argument, it's easier to use this function for all shutdown calls.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

http server shutdown takes a context: https://golang.org/pkg/net/http/#Server.Shutdown

we should do the same, no?

}

// StartEtcd launches the etcd server and HTTP handlers for client/server communication.
Expand All @@ -100,9 +100,9 @@ func StartEtcd(inCfg *Config) (e *Etcd, err error) {
return
}
if !serving {
// errored before starting gRPC server for serveCtx.grpcServerC
// errored before starting gRPC server for serveCtx.serversC
for _, sctx := range e.sctxs {
close(sctx.grpcServerC)
close(sctx.serversC)
}
}
e.Close()
Expand Down Expand Up @@ -219,23 +219,30 @@ func (e *Etcd) Config() Config {
return e.cfg
}

// Close gracefully shuts down all servers/listeners.
func (e *Etcd) Close() {
e.closeOnce.Do(func() { close(e.stopc) })

timeout := 2 * time.Second
if e.Server != nil {
timeout = e.Server.Cfg.ReqTimeout()
}
for _, sctx := range e.sctxs {
for gs := range sctx.grpcServerC {
e.stopGRPCServer(gs)
for ss := range sctx.serversC {
stopServers(ss, timeout)
}
}

for _, sctx := range e.sctxs {
sctx.cancel()
}

for i := range e.Clients {
if e.Clients[i] != nil {
e.Clients[i].Close()
}
}

for i := range e.metricsListeners {
e.metricsListeners[i].Close()
}
Expand All @@ -248,32 +255,46 @@ func (e *Etcd) Close() {
// close all idle connections in peer handler (wait up to 1-second)
for i := range e.Peers {
if e.Peers[i] != nil && e.Peers[i].close != nil {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
e.Peers[i].close(ctx)
cancel()
e.Peers[i].close(time.Second)
}
}
}

func (e *Etcd) stopGRPCServer(gs *grpc.Server) {
timeout := 2 * time.Second
if e.Server != nil {
timeout = e.Server.Cfg.ReqTimeout()
func stopServers(ss *servers, timeout time.Duration) {
shutdownNow := func() {
// first, close the http.Server
ctx, cancel := context.WithTimeout(context.Background(), timeout)
ss.http.Shutdown(ctx)
cancel()

// and then close grpc.Server; cancels all active RPCs
ss.grpc.Stop()
}

// do not grpc.Server.GracefulStop with TLS enabled etcd server
// See https://github.com/grpc/grpc-go/issues/1384#issuecomment-317124531
// and https://github.com/coreos/etcd/issues/8916
if ss.secure {
shutdownNow()
return
}

ch := make(chan struct{})
go func() {
defer close(ch)
// close listeners to stop accepting new connections,
// will block on any existing transports
gs.GracefulStop()
ss.grpc.GracefulStop()
}()

// wait until all pending RPCs are finished
select {
case <-ch:
case <-time.After(timeout):
// took too long, manually close open transports
// e.g. watch streams
gs.Stop()
shutdownNow()

// concurrent GracefulStop should be interrupted
<-ch
}
Expand All @@ -297,7 +318,7 @@ func startPeerListeners(cfg *Config) (peers []*peerListener, err error) {
for i := range peers {
if peers[i] != nil && peers[i].close != nil {
plog.Info("stopping listening for peers on ", cfg.LPUrls[i].String())
peers[i].close(context.Background())
peers[i].close(time.Second)
}
}
}()
Expand All @@ -311,13 +332,13 @@ func startPeerListeners(cfg *Config) (peers []*peerListener, err error) {
plog.Warningf("The scheme of peer url %s is HTTP while client cert auth (--peer-client-cert-auth) is enabled. Ignored client cert auth for this url.", u.String())
}
}
peers[i] = &peerListener{close: func(context.Context) error { return nil }}
peers[i] = &peerListener{close: func(time.Duration) error { return nil }}
peers[i].Listener, err = rafthttp.NewListener(u, &cfg.PeerTLSInfo)
if err != nil {
return nil, err
}
// once serve, overwrite with 'http.Server.Shutdown'
peers[i].close = func(context.Context) error {
peers[i].close = func(time.Duration) error {
return peers[i].Listener.Close()
}
plog.Info("listening for peers on ", u.String())
Expand All @@ -334,6 +355,7 @@ func (e *Etcd) servePeers() (err error) {
return err
}
}

for _, p := range e.Peers {
gs := v3rpc.Server(e.Server, peerTLScfg)
m := cmux.New(p.Listener)
Expand All @@ -345,12 +367,12 @@ func (e *Etcd) servePeers() (err error) {
}
go srv.Serve(m.Match(cmux.Any()))
p.serve = func() error { return m.Serve() }
p.close = func(ctx context.Context) error {
p.close = func(timeout time.Duration) error {
// gracefully shutdown http.Server
// close open listeners, idle connections
// until context cancel or time-out
e.stopGRPCServer(gs)
return srv.Shutdown(ctx)
stopServers(&servers{secure: peerTLScfg != nil, grpc: gs, http: srv}, timeout)
return nil
}
}

Expand Down
21 changes: 13 additions & 8 deletions embed/serve.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,13 +54,19 @@ type serveCtx struct {

userHandlers map[string]http.Handler
serviceRegister func(*grpc.Server)
grpcServerC chan *grpc.Server
serversC chan *servers
}

type servers struct {
secure bool
grpc *grpc.Server
http *http.Server
}

func newServeCtx() *serveCtx {
ctx, cancel := context.WithCancel(context.Background())
return &serveCtx{ctx: ctx, cancel: cancel, userHandlers: make(map[string]http.Handler),
grpcServerC: make(chan *grpc.Server, 2), // in case sctx.insecure,sctx.secure true
serversC: make(chan *servers, 2), // in case sctx.insecure,sctx.secure true
}
}

Expand All @@ -84,7 +90,6 @@ func (sctx *serveCtx) serve(

if sctx.insecure {
gs := v3rpc.Server(s, nil, gopts...)
sctx.grpcServerC <- gs
v3electionpb.RegisterElectionServer(gs, servElection)
v3lockpb.RegisterLockServer(gs, servLock)
if sctx.serviceRegister != nil {
Expand All @@ -93,9 +98,7 @@ func (sctx *serveCtx) serve(
grpcl := m.Match(cmux.HTTP2())
go func() { errHandler(gs.Serve(grpcl)) }()

opts := []grpc.DialOption{
grpc.WithInsecure(),
}
opts := []grpc.DialOption{grpc.WithInsecure()}
gwmux, err := sctx.registerGateway(opts)
if err != nil {
return err
Expand All @@ -109,6 +112,8 @@ func (sctx *serveCtx) serve(
}
httpl := m.Match(cmux.HTTP1())
go func() { errHandler(srvhttp.Serve(httpl)) }()

sctx.serversC <- &servers{grpc: gs, http: srvhttp}
plog.Noticef("serving insecure client requests on %s, this is strongly discouraged!", sctx.l.Addr().String())
}

Expand All @@ -118,7 +123,6 @@ func (sctx *serveCtx) serve(
return tlsErr
}
gs := v3rpc.Server(s, tlscfg, gopts...)
sctx.grpcServerC <- gs
v3electionpb.RegisterElectionServer(gs, servElection)
v3lockpb.RegisterLockServer(gs, servLock)
if sctx.serviceRegister != nil {
Expand Down Expand Up @@ -150,10 +154,11 @@ func (sctx *serveCtx) serve(
}
go func() { errHandler(srv.Serve(tlsl)) }()

sctx.serversC <- &servers{secure: true, grpc: gs, http: srv}
plog.Infof("serving client requests on %s", sctx.l.Addr().String())
}

close(sctx.grpcServerC)
close(sctx.serversC)
return m.Serve()
}

Expand Down
34 changes: 27 additions & 7 deletions integration/embed_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func TestEmbedEtcd(t *testing.T) {
{werr: "expected IP"},
}

urls := newEmbedURLs(10)
urls := newEmbedURLs(false, 10)

// setup defaults
for i := range tests {
Expand Down Expand Up @@ -105,12 +105,19 @@ func TestEmbedEtcd(t *testing.T) {
}
}

// TestEmbedEtcdGracefulStop ensures embedded server stops
func TestEmbedEtcdGracefulStopSecure(t *testing.T) { testEmbedEtcdGracefulStop(t, true) }
func TestEmbedEtcdGracefulStopInsecure(t *testing.T) { testEmbedEtcdGracefulStop(t, false) }

// testEmbedEtcdGracefulStop ensures embedded server stops
// cutting existing transports.
func TestEmbedEtcdGracefulStop(t *testing.T) {
func testEmbedEtcdGracefulStop(t *testing.T, secure bool) {
cfg := embed.NewConfig()
if secure {
cfg.ClientTLSInfo = testTLSInfo
cfg.PeerTLSInfo = testTLSInfo
}

urls := newEmbedURLs(2)
urls := newEmbedURLs(secure, 2)
setupEmbedCfg(cfg, []url.URL{urls[0]}, []url.URL{urls[1]})

cfg.Dir = filepath.Join(os.TempDir(), fmt.Sprintf("embed-etcd"))
Expand All @@ -123,7 +130,16 @@ func TestEmbedEtcdGracefulStop(t *testing.T) {
}
<-e.Server.ReadyNotify() // wait for e.Server to join the cluster

cli, err := clientv3.New(clientv3.Config{Endpoints: []string{urls[0].String()}})
clientCfg := clientv3.Config{
Endpoints: []string{urls[0].String()},
}
if secure {
clientCfg.TLS, err = testTLSInfo.ClientConfig()
if err != nil {
t.Fatal(err)
}
}
cli, err := clientv3.New(clientCfg)
if err != nil {
t.Fatal(err)
}
Expand All @@ -146,9 +162,13 @@ func TestEmbedEtcdGracefulStop(t *testing.T) {
}
}

func newEmbedURLs(n int) (urls []url.URL) {
func newEmbedURLs(secure bool, n int) (urls []url.URL) {
scheme := "unix"
if secure {
scheme = "unixs"
}
for i := 0; i < n; i++ {
u, _ := url.Parse(fmt.Sprintf("unix://localhost:%d%06d", os.Getpid(), i))
u, _ := url.Parse(fmt.Sprintf("%s://localhost:%d%06d", scheme, os.Getpid(), i))
urls = append(urls, *u)
}
return urls
Expand Down