diff --git a/embed/etcd.go b/embed/etcd.go index 5b7a17ca99a..c5c0d2d699f 100644 --- a/embed/etcd.go +++ b/embed/etcd.go @@ -192,11 +192,29 @@ func (e *Etcd) Config() Config { func (e *Etcd) Close() { e.closeOnce.Do(func() { close(e.stopc) }) - // (gRPC server) stops accepting new connections, - // RPCs, and blocks until all pending RPCs are finished + timeout := 2 * time.Second + if e.Server != nil { + timeout = e.Server.Cfg.ReqTimeout() + } for _, sctx := range e.sctxs { for gs := range sctx.grpcServerC { - gs.GracefulStop() + ch := make(chan struct{}) + go func() { + defer close(ch) + // close listeners to stop accepting new connections, + // will block on any existing transports + gs.GracefulStop() + }() + // wait until all pending RPCs are finished + select { + case <-ch: + case <-time.After(timeout): + // took too long, manually close open transports + // e.g. watch streams + gs.Stop() + // concurrent GracefulStop should be interrupted + <-ch + } } } diff --git a/integration/embed_test.go b/integration/embed_test.go index 8cba0b39605..751494eaea0 100644 --- a/integration/embed_test.go +++ b/integration/embed_test.go @@ -15,13 +15,16 @@ package integration import ( + "context" "fmt" "net/url" "os" "path/filepath" "strings" "testing" + "time" + "github.com/coreos/etcd/clientv3" "github.com/coreos/etcd/embed" ) @@ -102,6 +105,47 @@ func TestEmbedEtcd(t *testing.T) { } } +// TestEmbedEtcdGracefulStop ensures embedded server stops +// cutting existing transports. +func TestEmbedEtcdGracefulStop(t *testing.T) { + cfg := embed.NewConfig() + + urls := newEmbedURLs(2) + setupEmbedCfg(cfg, []url.URL{urls[0]}, []url.URL{urls[1]}) + + cfg.Dir = filepath.Join(os.TempDir(), fmt.Sprintf("embed-etcd")) + os.RemoveAll(cfg.Dir) + defer os.RemoveAll(cfg.Dir) + + e, err := embed.StartEtcd(cfg) + if err != nil { + t.Fatal(err) + } + <-e.Server.ReadyNotify() // wait for e.Server to join the cluster + + cli, err := clientv3.New(clientv3.Config{Endpoints: []string{urls[0].String()}}) + if err != nil { + t.Fatal(err) + } + defer cli.Close() + + // open watch connection + cli.Watch(context.Background(), "foo") + + donec := make(chan struct{}) + go func() { + e.Close() + close(donec) + }() + select { + case err := <-e.Err(): + t.Fatal(err) + case <-donec: + case <-time.After(2*time.Second + e.Server.Cfg.ReqTimeout()): + t.Fatalf("took too long to close server") + } +} + func newEmbedURLs(n int) (urls []url.URL) { for i := 0; i < n; i++ { u, _ := url.Parse(fmt.Sprintf("unix://localhost:%d%06d", os.Getpid(), i))