diff --git a/pkg/utils/etcdutil/etcdutil.go b/pkg/utils/etcdutil/etcdutil.go index c42ea0f895b8..ef987609f330 100644 --- a/pkg/utils/etcdutil/etcdutil.go +++ b/pkg/utils/etcdutil/etcdutil.go @@ -24,7 +24,6 @@ import ( "github.com/gogo/protobuf/proto" "github.com/pingcap/errors" - "github.com/pingcap/failpoint" "github.com/pingcap/log" "github.com/tikv/pd/pkg/errs" @@ -220,20 +219,17 @@ func createEtcdClient(tlsConfig *tls.Config, acUrls []url.URL) (*clientv3.Client lgc := zap.NewProductionConfig() lgc.Encoding = log.ZapEncodingName autoSyncInterval := defaultAutoSyncInterval - dialKeepAliveTime := defaultDialKeepAliveTime - dialKeepAliveTimeout := defaultDialKeepAliveTimeout failpoint.Inject("autoSyncInterval", func() { autoSyncInterval = 10 * time.Millisecond }) - client, err := clientv3.New(clientv3.Config{ Endpoints: endpoints, DialTimeout: defaultEtcdClientTimeout, AutoSyncInterval: autoSyncInterval, TLS: tlsConfig, LogConfig: &lgc, - DialKeepAliveTime: dialKeepAliveTime, - DialKeepAliveTimeout: dialKeepAliveTimeout, + DialKeepAliveTime: defaultDialKeepAliveTime, + DialKeepAliveTimeout: defaultDialKeepAliveTimeout, }) if err == nil { log.Info("create etcd v3 client", zap.Strings("endpoints", endpoints)) diff --git a/pkg/utils/etcdutil/etcdutil_test.go b/pkg/utils/etcdutil/etcdutil_test.go index 2769c3818b80..6e4976698c2e 100644 --- a/pkg/utils/etcdutil/etcdutil_test.go +++ b/pkg/utils/etcdutil/etcdutil_test.go @@ -249,7 +249,7 @@ func TestEtcdClientSync(t *testing.T) { require.NoError(t, failpoint.Disable("github.com/tikv/pd/pkg/utils/etcdutil/autoSyncInterval")) } -func TestEtcdWithDelayLeader(t *testing.T) { +func TestEtcdWithHangLeader(t *testing.T) { t.Parallel() re := require.New(t) // Start a etcd server. @@ -261,8 +261,8 @@ func TestEtcdWithDelayLeader(t *testing.T) { // Create a proxy to etcd1. proxyAddr := tempurl.Alloc() - var enableDelay atomic.Bool - go proxyWithDelay(re, ep1, proxyAddr, &enableDelay) + var enableDiscard atomic.Bool + go proxyWithDiscard(re, ep1, proxyAddr, &enableDiscard) // Create a etcd client with etcd1 as endpoint. urls, err := types.NewURLs([]string{proxyAddr}) @@ -270,14 +270,15 @@ func TestEtcdWithDelayLeader(t *testing.T) { client1, err := createEtcdClient(nil, urls) re.NoError(err) + // Add a new member and set the client endpoints to etcd1 and etcd2. etcd2 := checkAddEtcdMember(t, cfg1, client1) defer etcd2.Close() checkMembers(re, client1, []*embed.Etcd{etcd1, etcd2}) - etcd2Addr := etcd2.Config().LCUrls[0].String() client1.SetEndpoints(proxyAddr, etcd2Addr) - enableDelay.Store(true) + // Hang the etcd1 and wait for the client to connect to etcd2. + enableDiscard.Store(true) time.Sleep(defaultDialKeepAliveTime + defaultDialKeepAliveTimeout*2) _, err = EtcdKVGet(client1, "test/key1") re.NoError(err) @@ -317,7 +318,7 @@ func checkMembers(re *require.Assertions, client *clientv3.Client, etcds []*embe } } -func proxyWithDelay(re *require.Assertions, server, proxy string, enableDelay *atomic.Bool) { +func proxyWithDiscard(re *require.Assertions, server, proxy string, enableDiscard *atomic.Bool) { server = strings.TrimPrefix(server, "http://") proxy = strings.TrimPrefix(proxy, "http://") l, err := net.Listen("tcp", proxy) @@ -328,48 +329,44 @@ func proxyWithDelay(re *require.Assertions, server, proxy string, enableDelay *a go func(connect net.Conn) { serverConnect, err := net.Dial("tcp", server) re.NoError(err) - pipe(connect, serverConnect, enableDelay) + pipe(connect, serverConnect, enableDiscard) }(connect) } } -func pipe(src net.Conn, dst net.Conn, enableDelay *atomic.Bool) { +func pipe(src net.Conn, dst net.Conn, enableDiscard *atomic.Bool) { errChan := make(chan error, 1) - closePipe := func() { - dst.Close() - src.Close() - } go func() { - err := ioCopy(src, dst, enableDelay) + err := ioCopy(src, dst, enableDiscard) errChan <- err - closePipe() }() go func() { - err := ioCopy(dst, src, enableDelay) + err := ioCopy(dst, src, enableDiscard) errChan <- err - closePipe() }() <-errChan + dst.Close() + src.Close() } -func ioCopy(dst io.Writer, src io.Reader, enableDelay *atomic.Bool) (err error) { - buf := make([]byte, 32*1024) +func ioCopy(dst io.Writer, src io.Reader, enableDiscard *atomic.Bool) (err error) { + buffer := make([]byte, 32*1024) for { - if enableDelay.Load() { + if enableDiscard.Load() { io.Copy(io.Discard, src) } - nr, er := src.Read(buf) - if nr > 0 { - nw, ew := dst.Write(buf[0:nr]) - if ew != nil { - return ew + readNum, errRead := src.Read(buffer) + if readNum > 0 { + writeNum, errWrite := dst.Write(buffer[0:readNum]) + if errWrite != nil { + return errWrite } - if nr != nw { + if readNum != writeNum { return io.ErrShortWrite } } - if er != nil { - err = er + if errRead != nil { + err = errRead break } }