diff --git a/dht_test.go b/dht_test.go index 99ead547..22eb03cf 100644 --- a/dht_test.go +++ b/dht_test.go @@ -1493,6 +1493,11 @@ func TestInvalidServer(t *testing.T) { for _, m := range []*IpfsDHT{m0, m1} { // Hang on every request. m.host.SetStreamHandler(protocol, func(s network.Stream) { + select { + case <-ctx.Done(): + return + default: + } r := msgio.NewVarintReaderSize(s, network.MessageSizeMax) msgbytes, err := r.ReadMsg() if err != nil { @@ -1505,7 +1510,7 @@ func TestInvalidServer(t *testing.T) { } // answer with an empty response message - resp := pb.NewMessage(req.GetType(), nil, req.GetClusterLevel()) + resp := pb.NewMessage(req.GetType(), make([]byte, 32), req.GetClusterLevel()) // send out response msg err = net.WriteMsg(s, resp) @@ -1526,9 +1531,14 @@ func TestInvalidServer(t *testing.T) { time.Sleep(time.Millisecond * 5) // just in case... // find the provider for k from m0 - provs, err := m0.FindProviders(ctx, k) - if err != nil { - t.Fatal(err) + maxRetries := 3 + var provs []peer.AddrInfo + var err error + for i := 0; i < maxRetries && len(provs) == 0; i++ { + provs, err = m0.FindProviders(ctx, k) + if err != nil { + t.Fatal(err) + } } if len(provs) == 0 { t.Fatal("Expected to get a provider back") @@ -1555,7 +1565,10 @@ func TestInvalidServer(t *testing.T) { // contains more than bucketSize (2) entries, lookupCheck is enabled and m1 // shouldn't be added, because it fails the lookupCheck (hang on all requests). if s0.routingTable.Find(s1.self) == "" { - t.Fatal("Well behaving DHT server should have been added to the server routing table") + time.Sleep(time.Millisecond * 5) // just in case... + if s0.routingTable.Find(s1.self) == "" { + t.Fatal("Well behaving DHT server should have been added to the server routing table") + } } if s0.routingTable.Find(m1.self) != "" { t.Fatal("Misbehaving DHT servers should not be added to routing table if well populated")