diff --git a/p2p/host/basic/basic_host_test.go b/p2p/host/basic/basic_host_test.go index eeb41c4ed3..9d05dfa1b5 100644 --- a/p2p/host/basic/basic_host_test.go +++ b/p2p/host/basic/basic_host_test.go @@ -1,7 +1,6 @@ package basichost import ( - "bytes" "context" "fmt" "io" @@ -49,9 +48,7 @@ func TestHostSimple(t *testing.T) { defer h2.Close() h2pi := h2.Peerstore().PeerInfo(h2.ID()) - if err := h1.Connect(ctx, h2pi); err != nil { - t.Fatal(err) - } + require.NoError(t, h1.Connect(ctx, h2pi)) piper, pipew := io.Pipe() h2.SetStreamHandler(protocol.TestingID, func(s network.Stream) { @@ -61,33 +58,24 @@ func TestHostSimple(t *testing.T) { }) s, err := h1.NewStream(ctx, h2pi.ID, protocol.TestingID) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) // write to the stream buf1 := []byte("abcdefghijkl") - if _, err := s.Write(buf1); err != nil { - t.Fatal(err) - } + _, err = s.Write(buf1) + require.NoError(t, err) // get it from the stream (echoed) buf2 := make([]byte, len(buf1)) - if _, err := io.ReadFull(s, buf2); err != nil { - t.Fatal(err) - } - if !bytes.Equal(buf1, buf2) { - t.Fatalf("buf1 != buf2 -- %x != %x", buf1, buf2) - } + _, err = io.ReadFull(s, buf2) + require.NoError(t, err) + require.Equal(t, buf1, buf2) // get it from the pipe (tee) buf3 := make([]byte, len(buf1)) - if _, err := io.ReadFull(piper, buf3); err != nil { - t.Fatal(err) - } - if !bytes.Equal(buf1, buf3) { - t.Fatalf("buf1 != buf3 -- %x != %x", buf1, buf3) - } + _, err = io.ReadFull(piper, buf3) + require.NoError(t, err) + require.Equal(t, buf1, buf3) } func TestMultipleClose(t *testing.T) { @@ -109,9 +97,7 @@ func TestSignedPeerRecordWithNoListenAddrs(t *testing.T) { } // now add a listen addr - if err := h.Network().Listen(ma.StringCast("/ip4/0.0.0.0/tcp/0")); err != nil { - t.Fatal(err) - } + require.NoError(t, h.Network().Listen(ma.StringCast("/ip4/0.0.0.0/tcp/0"))) if len(h.Addrs()) < 1 { t.Errorf("expected at least 1 listen addr, got %d", len(h.Addrs())) } @@ -135,9 +121,7 @@ func TestProtocolHandlerEvents(t *testing.T) { defer h.Close() sub, err := h.EventBus().Subscribe(&event.EvtLocalProtocolsUpdated{}, eventbus.BufSize(16)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) defer sub.Close() // the identify service adds new protocol handlers shortly after the host @@ -264,6 +248,8 @@ func TestAllAddrs(t *testing.T) { t.Fatal("expected addrs to contain original addr") } +// getHostPair gets a new pair of hosts. +// The first host initiates the connection to the second host. func getHostPair(t *testing.T) (host.Host, host.Host) { t.Helper() @@ -275,9 +261,7 @@ func getHostPair(t *testing.T) (host.Host, host.Host) { ctx, cancel := context.WithTimeout(context.Background(), time.Second) defer cancel() h2pi := h2.Peerstore().PeerInfo(h2.ID()) - if err := h1.Connect(ctx, h2pi); err != nil { - t.Fatal(err) - } + require.NoError(t, h1.Connect(ctx, h2pi)) return h1, h2 } @@ -301,70 +285,55 @@ func TestHostProtoPreference(t *testing.T) { defer h1.Close() defer h2.Close() - protoOld := protocol.ID("/testing") - protoNew := protocol.ID("/testing/1.1.0") - protoMinor := protocol.ID("/testing/1.2.0") + const ( + protoOld = protocol.ID("/testing") + protoNew = protocol.ID("/testing/1.1.0") + protoMinor = protocol.ID("/testing/1.2.0") + ) connectedOn := make(chan protocol.ID) - handler := func(s network.Stream) { connectedOn <- s.Protocol() s.Close() } // Prevent pushing identify information so this test works. - h2.RemoveStreamHandler(identify.IDPush) - h2.RemoveStreamHandler(identify.IDDelta) + h1.RemoveStreamHandler(identify.IDPush) + h1.RemoveStreamHandler(identify.IDDelta) - h1.SetStreamHandler(protoOld, handler) + h2.SetStreamHandler(protoOld, handler) - s, err := h2.NewStream(ctx, h1.ID(), protoMinor, protoNew, protoOld) - if err != nil { - t.Fatal(err) - } + s, err := h1.NewStream(ctx, h2.ID(), protoMinor, protoNew, protoOld) + require.NoError(t, err) // force the lazy negotiation to complete _, err = s.Write(nil) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) assertWait(t, connectedOn, protoOld) s.Close() mfunc, err := helpers.MultistreamSemverMatcher(protoMinor) - if err != nil { - t.Fatal(err) - } - - h1.SetStreamHandlerMatch(protoMinor, mfunc, handler) + require.NoError(t, err) + h2.SetStreamHandlerMatch(protoMinor, mfunc, handler) // remembered preference will be chosen first, even when the other side newly supports it - s2, err := h2.NewStream(ctx, h1.ID(), protoMinor, protoNew, protoOld) - if err != nil { - t.Fatal(err) - } + s2, err := h1.NewStream(ctx, h2.ID(), protoMinor, protoNew, protoOld) + require.NoError(t, err) // required to force 'lazy' handshake _, err = s2.Write([]byte("hello")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) assertWait(t, connectedOn, protoOld) - s2.Close() - s3, err := h2.NewStream(ctx, h1.ID(), protoMinor) - if err != nil { - t.Fatal(err) - } + s3, err := h1.NewStream(ctx, h2.ID(), protoMinor) + require.NoError(t, err) // Force a lazy handshake as we may have received a protocol update by this point. _, err = s3.Write([]byte("hello")) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) assertWait(t, connectedOn, protoMinor) s3.Close() @@ -397,6 +366,8 @@ func TestHostProtoPreknowledge(t *testing.T) { require.NoError(t, err) h2, err := NewHost(swarmt.GenSwarm(t), nil) require.NoError(t, err) + defer h1.Close() + defer h2.Close() conn := make(chan protocol.ID) handler := func(s network.Stream) { @@ -404,18 +375,13 @@ func TestHostProtoPreknowledge(t *testing.T) { s.Close() } - h1.SetStreamHandler("/super", handler) - + h2.SetStreamHandler("/super", handler) // Prevent pushing identify information so this test actually _uses_ the super protocol. - h2.RemoveStreamHandler(identify.IDPush) - h2.RemoveStreamHandler(identify.IDDelta) + h1.RemoveStreamHandler(identify.IDPush) + h1.RemoveStreamHandler(identify.IDDelta) h2pi := h2.Peerstore().PeerInfo(h2.ID()) - if err := h1.Connect(ctx, h2pi); err != nil { - t.Fatal(err) - } - defer h1.Close() - defer h2.Close() + require.NoError(t, h1.Connect(ctx, h2pi)) // wait for identify handshake to finish completely select { @@ -430,12 +396,10 @@ func TestHostProtoPreknowledge(t *testing.T) { t.Fatal("timed out waiting for identify") } - h1.SetStreamHandler("/foo", handler) + h2.SetStreamHandler("/foo", handler) - s, err := h2.NewStream(ctx, h1.ID(), "/foo", "/bar", "/super") - if err != nil { - t.Fatal(err) - } + s, err := h1.NewStream(ctx, h2.ID(), "/foo", "/bar", "/super") + require.NoError(t, err) select { case p := <-conn: @@ -444,10 +408,7 @@ func TestHostProtoPreknowledge(t *testing.T) { } _, err = s.Read(nil) - if err != nil { - t.Fatal(err) - } - + require.NoError(t, err) assertWait(t, conn, "/super") s.Close() @@ -462,29 +423,20 @@ func TestNewDialOld(t *testing.T) { defer h2.Close() connectedOn := make(chan protocol.ID) - h1.SetStreamHandler("/testing", func(s network.Stream) { + h2.SetStreamHandler("/testing", func(s network.Stream) { connectedOn <- s.Protocol() s.Close() }) - s, err := h2.NewStream(ctx, h1.ID(), "/testing/1.0.0", "/testing") - if err != nil { - t.Fatal(err) - } + s, err := h1.NewStream(ctx, h2.ID(), "/testing/1.0.0", "/testing") + require.NoError(t, err) // force the lazy negotiation to complete _, err = s.Write(nil) - if err != nil { - t.Fatal(err) - } - + require.NoError(t, err) assertWait(t, connectedOn, "/testing") - if s.Protocol() != "/testing" { - t.Fatal("should have gotten /testing") - } - - s.Close() + require.Equal(t, s.Protocol(), protocol.ID("/testing"), "should have gotten /testing") } func TestProtoDowngrade(t *testing.T) { @@ -496,51 +448,32 @@ func TestProtoDowngrade(t *testing.T) { defer h2.Close() connectedOn := make(chan protocol.ID) - h1.SetStreamHandler("/testing/1.0.0", func(s network.Stream) { + h2.SetStreamHandler("/testing/1.0.0", func(s network.Stream) { defer s.Close() result, err := ioutil.ReadAll(s) - if err != nil { - t.Error(err) - } else if string(result) != "bar" { - t.Error("wrong result") - } + assert.NoError(t, err) + assert.Equal(t, string(result), "bar") connectedOn <- s.Protocol() }) - s, err := h2.NewStream(ctx, h1.ID(), "/testing/1.0.0", "/testing") - if err != nil { - t.Fatal(err) - } - - if s.Protocol() != "/testing/1.0.0" { - t.Fatalf("should have gotten /testing/1.0.0, got %s", s.Protocol()) - } + s, err := h1.NewStream(ctx, h2.ID(), "/testing/1.0.0", "/testing") + require.NoError(t, err) + require.Equal(t, s.Protocol(), protocol.ID("/testing/1.0.0"), "should have gotten /testing/1.0.0, got %s", s.Protocol()) _, err = s.Write([]byte("bar")) - if err != nil { - t.Fatal(err) - } - err = s.CloseWrite() - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) + require.NoError(t, s.CloseWrite()) assertWait(t, connectedOn, "/testing/1.0.0") - if err := s.Close(); err != nil { - t.Error(err) - } - - h2.Network().ClosePeer(h1.ID()) + require.NoError(t, s.Close()) - h1.RemoveStreamHandler("/testing/1.0.0") - h1.SetStreamHandler("/testing", func(s network.Stream) { + h1.Network().ClosePeer(h2.ID()) + h2.RemoveStreamHandler("/testing/1.0.0") + h2.SetStreamHandler("/testing", func(s network.Stream) { defer s.Close() result, err := ioutil.ReadAll(s) - if err != nil { - t.Error(err) - } else if string(result) != "foo" { - t.Error("wrong result") - } + assert.NoError(t, err) + assert.Equal(t, string(result), "foo") connectedOn <- s.Protocol() }) @@ -549,45 +482,24 @@ func TestProtoDowngrade(t *testing.T) { time.Sleep(time.Millisecond) h2pi := h2.Peerstore().PeerInfo(h2.ID()) - if err := h1.Connect(ctx, h2pi); err != nil { - t.Fatal(err) - } + require.NoError(t, h1.Connect(ctx, h2pi)) - s2, err := h2.NewStream(ctx, h1.ID(), "/testing/1.0.0", "/testing") - if err != nil { - t.Fatal(err) - } - - if s2.Protocol() != "/testing" { - t.Errorf("should have gotten /testing, got %s, %s", s.Protocol(), s.Conn()) - } + s2, err := h1.NewStream(ctx, h2.ID(), "/testing/1.0.0", "/testing") + require.NoError(t, err) + require.Equal(t, s2.Protocol(), protocol.ID("/testing"), "should have gotten /testing, got %s, %s", s.Protocol(), s.Conn()) _, err = s2.Write([]byte("foo")) - if err != nil { - t.Error(err) - } - err = s2.CloseWrite() - if err != nil { - t.Error(err) - } + require.NoError(t, err) + require.NoError(t, s2.CloseWrite()) assertWait(t, connectedOn, "/testing") - if err := s.Close(); err != nil { - t.Error(err) - } } func TestAddrResolution(t *testing.T) { ctx := context.Background() - p1, err := test.RandPeerID() - if err != nil { - t.Error(err) - } - p2, err := test.RandPeerID() - if err != nil { - t.Error(err) - } + p1 := test.RandPeerIDFatal(t) + p2 := test.RandPeerIDFatal(t) addr1 := ma.StringCast("/dnsaddr/example.com") addr2 := ma.StringCast("/ip4/192.0.2.1/tcp/123") p2paddr1 := ma.StringCast("/dnsaddr/example.com/p2p/" + p1.Pretty()) @@ -600,18 +512,14 @@ func TestAddrResolution(t *testing.T) { }}, } resolver, err := madns.NewResolver(madns.WithDefaultResolver(backend)) - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) h, err := NewHost(swarmt.GenSwarm(t), &HostOpts{MultiaddrResolver: resolver}) require.NoError(t, err) defer h.Close() pi, err := peer.AddrInfoFromP2pAddr(p2paddr1) - if err != nil { - t.Error(err) - } + require.NoError(t, err) tctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) defer cancel() @@ -855,7 +763,7 @@ func TestNegotiationCancel(t *testing.T) { defer h2.Close() // pre-negotiation so we can make the negotiation hang. - h1.Network().SetStreamHandler(func(s network.Stream) { + h2.Network().SetStreamHandler(func(s network.Stream) { <-ctx.Done() // wait till the test is done. s.Reset() }) @@ -865,7 +773,7 @@ func TestNegotiationCancel(t *testing.T) { errCh := make(chan error, 1) go func() { - s, err := h2.NewStream(ctx2, h1.ID(), "/testing") + s, err := h1.NewStream(ctx2, h2.ID(), "/testing") if s != nil { errCh <- fmt.Errorf("expected to fail negotiation") return