From b3a07d8b91ccd79f2679fd351c70b4a2d1455d88 Mon Sep 17 00:00:00 2001 From: Dhruba Basu <7675102+dhrubabasu@users.noreply.github.com> Date: Wed, 3 May 2023 13:34:55 -0400 Subject: [PATCH] Use `require.IsType` for type assertions in tests (#1458) --- api/auth/auth_test.go | 4 +- database/manager/manager_test.go | 18 ++--- indexer/indexer_test.go | 39 +++++---- message/inbound_msg_builder_test.go | 52 ++++++------ message/messages_test.go | 4 +- .../throttling/bandwidth_throttler_test.go | 4 +- .../inbound_resource_throttler_test.go | 4 +- .../consensus/snowball/unary_snowball_test.go | 10 ++- .../snowball/unary_snowflake_test.go | 10 ++- snow/engine/avalanche/getter/getter_test.go | 18 +++-- .../snowman/bootstrap/bootstrapper_test.go | 24 +++--- snow/engine/snowman/getter/getter_test.go | 12 +-- snow/engine/snowman/syncer/utils_test.go | 4 +- snow/networking/sender/sender_test.go | 40 +++++----- .../tracker/resource_tracker_test.go | 4 +- snow/networking/tracker/targeter_test.go | 4 +- utils/buffer/unbounded_deque_test.go | 18 ++--- utils/dynamicip/updater_test.go | 4 +- utils/window/window_test.go | 4 +- vms/avm/blocks/block_test.go | 4 +- vms/avm/blocks/builder/builder_test.go | 16 ++-- vms/avm/network/network_test.go | 4 +- vms/components/chain/state_test.go | 21 +++-- vms/components/message/message_test.go | 4 +- vms/components/message/tx_test.go | 4 +- vms/platformvm/blocks/builder/builder_test.go | 4 +- .../blocks/executor/manager_test.go | 8 +- vms/platformvm/blocks/parse_test.go | 27 +++---- vms/platformvm/service_test.go | 3 +- .../txs/executor/reward_validator_test.go | 20 ++--- vms/platformvm/vm_test.go | 60 +++++--------- vms/proposervm/batched_vm_test.go | 24 ++---- vms/proposervm/post_fork_block_test.go | 30 +++---- vms/proposervm/post_fork_option_test.go | 68 +++++++--------- vms/proposervm/pre_fork_block_test.go | 48 +++++------ vms/proposervm/vm_byzantine_test.go | 19 ++--- vms/proposervm/vm_test.go | 43 +++++----- vms/rpcchainvm/batched_vm_test.go | 17 ++-- vms/secp256k1fx/keychain_test.go | 12 +-- x/merkledb/db_test.go | 24 +++--- x/merkledb/trie_test.go | 80 +++++++++---------- 41 files changed, 379 insertions(+), 438 deletions(-) diff --git a/api/auth/auth_test.go b/api/auth/auth_test.go index c6a92585e15..6728cc62028 100644 --- a/api/auth/auth_test.go +++ b/api/auth/auth_test.go @@ -69,8 +69,8 @@ func TestNewTokenHappyPath(t *testing.T) { }) require.NoError(t, err, "couldn't parse new token") - claims, ok := token.Claims.(*endpointClaims) - require.True(t, ok, "expected auth token's claims to be type endpointClaims but is different type") + require.IsType(t, &endpointClaims{}, token.Claims) + claims := token.Claims.(*endpointClaims) require.ElementsMatch(t, endpoints, claims.Endpoints, "token has wrong endpoint claims") shouldExpireAt := jwt.NewNumericDate(now.Add(defaultTokenLifespan)) diff --git a/database/manager/manager_test.go b/database/manager/manager_test.go index e292024d6c5..0a24753f04c 100644 --- a/database/manager/manager_test.go +++ b/database/manager/manager_test.go @@ -306,12 +306,9 @@ func TestMeterDBManager(t *testing.T) { dbs := manager.GetDatabases() require.Len(dbs, 3) - _, ok := dbs[0].Database.(*meterdb.Database) - require.True(ok) - _, ok = dbs[1].Database.(*meterdb.Database) - require.False(ok) - _, ok = dbs[2].Database.(*meterdb.Database) - require.False(ok) + require.IsType(&meterdb.Database{}, dbs[0].Database) + require.IsType(&memdb.Database{}, dbs[1].Database) + require.IsType(&memdb.Database{}, dbs[2].Database) // Confirm that the error from a name conflict is handled correctly _, err = m.NewMeterDBManager("", registry) @@ -355,12 +352,9 @@ func TestCompleteMeterDBManager(t *testing.T) { dbs := manager.GetDatabases() require.Len(dbs, 3) - _, ok := dbs[0].Database.(*meterdb.Database) - require.True(ok) - _, ok = dbs[1].Database.(*meterdb.Database) - require.True(ok) - _, ok = dbs[2].Database.(*meterdb.Database) - require.True(ok) + require.IsType(&meterdb.Database{}, dbs[0].Database) + require.IsType(&meterdb.Database{}, dbs[1].Database) + require.IsType(&meterdb.Database{}, dbs[2].Database) // Confirm that the error from a name conflict is handled correctly _, err = m.NewCompleteMeterDBManager("", registry) diff --git a/indexer/indexer_test.go b/indexer/indexer_test.go index 6ad1784046f..5a5bb912f01 100644 --- a/indexer/indexer_test.go +++ b/indexer/indexer_test.go @@ -67,8 +67,8 @@ func TestNewIndexer(t *testing.T) { idxrIntf, err := NewIndexer(config) require.NoError(err) - idxr, ok := idxrIntf.(*indexer) - require.True(ok) + require.IsType(&indexer{}, idxrIntf) + idxr := idxrIntf.(*indexer) require.NotNil(idxr.codec) require.NotNil(idxr.log) require.NotNil(idxr.db) @@ -118,8 +118,8 @@ func TestMarkHasRunAndShutdown(t *testing.T) { config.DB = versiondb.New(baseDB) idxrIntf, err = NewIndexer(config) require.NoError(err) - idxr, ok := idxrIntf.(*indexer) - require.True(ok) + require.IsType(&indexer{}, idxrIntf) + idxr := idxrIntf.(*indexer) require.True(idxr.hasRunBefore) require.NoError(idxr.Close()) shutdown.Wait() @@ -150,8 +150,8 @@ func TestIndexer(t *testing.T) { // Create indexer idxrIntf, err := NewIndexer(config) require.NoError(err) - idxr, ok := idxrIntf.(*indexer) - require.True(ok) + require.IsType(&indexer{}, idxrIntf) + idxr := idxrIntf.(*indexer) now := time.Now() idxr.clock.Set(now) @@ -232,10 +232,10 @@ func TestIndexer(t *testing.T) { config.DB = versiondb.New(baseDB) idxrIntf, err = NewIndexer(config) require.NoError(err) - idxr, ok = idxrIntf.(*indexer) + require.IsType(&indexer{}, idxrIntf) + idxr = idxrIntf.(*indexer) now = time.Now() idxr.clock.Set(now) - require.True(ok) require.Len(idxr.blockIndices, 0) require.Len(idxr.txIndices, 0) require.Len(idxr.vtxIndices, 0) @@ -389,8 +389,8 @@ func TestIndexer(t *testing.T) { config.DB = versiondb.New(baseDB) idxrIntf, err = NewIndexer(config) require.NoError(err) - idxr, ok = idxrIntf.(*indexer) - require.True(ok) + require.IsType(&indexer{}, idxrIntf) + idxr = idxrIntf.(*indexer) idxr.RegisterChain("chain1", chain1Ctx, chainVM) idxr.RegisterChain("chain2", chain2Ctx, dagVM) @@ -427,8 +427,8 @@ func TestIncompleteIndex(t *testing.T) { } idxrIntf, err := NewIndexer(config) require.NoError(err) - idxr, ok := idxrIntf.(*indexer) - require.True(ok) + require.IsType(&indexer{}, idxrIntf) + idxr := idxrIntf.(*indexer) require.False(idxr.indexingEnabled) // Register a chain @@ -454,8 +454,8 @@ func TestIncompleteIndex(t *testing.T) { config.DB = versiondb.New(baseDB) idxrIntf, err = NewIndexer(config) require.NoError(err) - idxr, ok = idxrIntf.(*indexer) - require.True(ok) + require.IsType(&indexer{}, idxrIntf) + idxr = idxrIntf.(*indexer) require.True(idxr.indexingEnabled) // Register the chain again. Should die due to incomplete index. @@ -470,8 +470,8 @@ func TestIncompleteIndex(t *testing.T) { config.DB = versiondb.New(baseDB) idxrIntf, err = NewIndexer(config) require.NoError(err) - idxr, ok = idxrIntf.(*indexer) - require.True(ok) + require.IsType(&indexer{}, idxrIntf) + idxr = idxrIntf.(*indexer) require.True(idxr.allowIncompleteIndex) // Register the chain again. Should be OK @@ -486,8 +486,7 @@ func TestIncompleteIndex(t *testing.T) { config.DB = versiondb.New(baseDB) idxrIntf, err = NewIndexer(config) require.NoError(err) - _, ok = idxrIntf.(*indexer) - require.True(ok) + require.IsType(&indexer{}, idxrIntf) } // Ensure we only index chains in the primary network @@ -513,8 +512,8 @@ func TestIgnoreNonDefaultChains(t *testing.T) { // Create indexer idxrIntf, err := NewIndexer(config) require.NoError(err) - idxr, ok := idxrIntf.(*indexer) - require.True(ok) + require.IsType(&indexer{}, idxrIntf) + idxr := idxrIntf.(*indexer) // Assert state is right chain1Ctx := snow.DefaultConsensusContextTest() diff --git a/message/inbound_msg_builder_test.go b/message/inbound_msg_builder_test.go index 667a205d1df..068ce857f90 100644 --- a/message/inbound_msg_builder_test.go +++ b/message/inbound_msg_builder_test.go @@ -65,8 +65,8 @@ func TestInboundMsgBuilder(t *testing.T) { require.Equal(nodeID, msg.NodeID()) require.False(msg.Expiration().Before(start.Add(deadline))) require.False(end.Add(deadline).Before(msg.Expiration())) - innerMsg, ok := msg.Message().(*p2p.GetStateSummaryFrontier) - require.True(ok) + require.IsType(&p2p.GetStateSummaryFrontier{}, msg.Message()) + innerMsg := msg.Message().(*p2p.GetStateSummaryFrontier) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) }, @@ -87,8 +87,8 @@ func TestInboundMsgBuilder(t *testing.T) { require.Equal(StateSummaryFrontierOp, msg.Op()) require.Equal(nodeID, msg.NodeID()) require.Equal(mockable.MaxTime, msg.Expiration()) - innerMsg, ok := msg.Message().(*p2p.StateSummaryFrontier) - require.True(ok) + require.IsType(&p2p.StateSummaryFrontier{}, msg.Message()) + innerMsg := msg.Message().(*p2p.StateSummaryFrontier) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) require.Equal(summary, innerMsg.Summary) @@ -114,8 +114,8 @@ func TestInboundMsgBuilder(t *testing.T) { require.Equal(nodeID, msg.NodeID()) require.False(msg.Expiration().Before(start.Add(deadline))) require.False(end.Add(deadline).Before(msg.Expiration())) - innerMsg, ok := msg.Message().(*p2p.GetAcceptedStateSummary) - require.True(ok) + require.IsType(&p2p.GetAcceptedStateSummary{}, msg.Message()) + innerMsg := msg.Message().(*p2p.GetAcceptedStateSummary) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) require.Equal(heights, innerMsg.Heights) @@ -137,8 +137,8 @@ func TestInboundMsgBuilder(t *testing.T) { require.Equal(AcceptedStateSummaryOp, msg.Op()) require.Equal(nodeID, msg.NodeID()) require.Equal(mockable.MaxTime, msg.Expiration()) - innerMsg, ok := msg.Message().(*p2p.AcceptedStateSummary) - require.True(ok) + require.IsType(&p2p.AcceptedStateSummary{}, msg.Message()) + innerMsg := msg.Message().(*p2p.AcceptedStateSummary) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) summaryIDsBytes := make([][]byte, len(summaryIDs)) @@ -169,8 +169,8 @@ func TestInboundMsgBuilder(t *testing.T) { require.Equal(nodeID, msg.NodeID()) require.False(msg.Expiration().Before(start.Add(deadline))) require.False(end.Add(deadline).Before(msg.Expiration())) - innerMsg, ok := msg.Message().(*p2p.GetAcceptedFrontier) - require.True(ok) + require.IsType(&p2p.GetAcceptedFrontier{}, msg.Message()) + innerMsg := msg.Message().(*p2p.GetAcceptedFrontier) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) require.Equal(engineType, innerMsg.EngineType) @@ -192,8 +192,8 @@ func TestInboundMsgBuilder(t *testing.T) { require.Equal(AcceptedFrontierOp, msg.Op()) require.Equal(nodeID, msg.NodeID()) require.Equal(mockable.MaxTime, msg.Expiration()) - innerMsg, ok := msg.Message().(*p2p.AcceptedFrontier) - require.True(ok) + require.IsType(&p2p.AcceptedFrontier{}, msg.Message()) + innerMsg := msg.Message().(*p2p.AcceptedFrontier) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) containerIDsBytes := make([][]byte, len(containerIDs)) @@ -225,8 +225,8 @@ func TestInboundMsgBuilder(t *testing.T) { require.Equal(nodeID, msg.NodeID()) require.False(msg.Expiration().Before(start.Add(deadline))) require.False(end.Add(deadline).Before(msg.Expiration())) - innerMsg, ok := msg.Message().(*p2p.GetAccepted) - require.True(ok) + require.IsType(&p2p.GetAccepted{}, msg.Message()) + innerMsg := msg.Message().(*p2p.GetAccepted) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) require.Equal(engineType, innerMsg.EngineType) @@ -248,8 +248,8 @@ func TestInboundMsgBuilder(t *testing.T) { require.Equal(AcceptedOp, msg.Op()) require.Equal(nodeID, msg.NodeID()) require.Equal(mockable.MaxTime, msg.Expiration()) - innerMsg, ok := msg.Message().(*p2p.Accepted) - require.True(ok) + require.IsType(&p2p.Accepted{}, msg.Message()) + innerMsg := msg.Message().(*p2p.Accepted) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) containerIDsBytes := make([][]byte, len(containerIDs)) @@ -281,8 +281,8 @@ func TestInboundMsgBuilder(t *testing.T) { require.Equal(nodeID, msg.NodeID()) require.False(msg.Expiration().Before(start.Add(deadline))) require.False(end.Add(deadline).Before(msg.Expiration())) - innerMsg, ok := msg.Message().(*p2p.PushQuery) - require.True(ok) + require.IsType(&p2p.PushQuery{}, msg.Message()) + innerMsg := msg.Message().(*p2p.PushQuery) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) require.Equal(container, innerMsg.Container) @@ -310,8 +310,8 @@ func TestInboundMsgBuilder(t *testing.T) { require.Equal(nodeID, msg.NodeID()) require.False(msg.Expiration().Before(start.Add(deadline))) require.False(end.Add(deadline).Before(msg.Expiration())) - innerMsg, ok := msg.Message().(*p2p.PullQuery) - require.True(ok) + require.IsType(&p2p.PullQuery{}, msg.Message()) + innerMsg := msg.Message().(*p2p.PullQuery) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) require.Equal(containerIDs[0][:], innerMsg.ContainerId) @@ -335,8 +335,8 @@ func TestInboundMsgBuilder(t *testing.T) { require.Equal(ChitsOp, msg.Op()) require.Equal(nodeID, msg.NodeID()) require.Equal(mockable.MaxTime, msg.Expiration()) - innerMsg, ok := msg.Message().(*p2p.Chits) - require.True(ok) + require.IsType(&p2p.Chits{}, msg.Message()) + innerMsg := msg.Message().(*p2p.Chits) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) containerIDsBytes := make([][]byte, len(containerIDs)) @@ -373,8 +373,8 @@ func TestInboundMsgBuilder(t *testing.T) { require.Equal(nodeID, msg.NodeID()) require.False(msg.Expiration().Before(start.Add(deadline))) require.False(end.Add(deadline).Before(msg.Expiration())) - innerMsg, ok := msg.Message().(*p2p.AppRequest) - require.True(ok) + require.IsType(&p2p.AppRequest{}, msg.Message()) + innerMsg := msg.Message().(*p2p.AppRequest) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) require.Equal(appBytes, innerMsg.AppBytes) @@ -396,8 +396,8 @@ func TestInboundMsgBuilder(t *testing.T) { require.Equal(AppResponseOp, msg.Op()) require.Equal(nodeID, msg.NodeID()) require.Equal(mockable.MaxTime, msg.Expiration()) - innerMsg, ok := msg.Message().(*p2p.AppResponse) - require.True(ok) + require.IsType(&p2p.AppResponse{}, msg.Message()) + innerMsg := msg.Message().(*p2p.AppResponse) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) require.Equal(appBytes, innerMsg.AppBytes) diff --git a/message/messages_test.go b/message/messages_test.go index c04e3ea44dd..ead180ffaff 100644 --- a/message/messages_test.go +++ b/message/messages_test.go @@ -872,7 +872,7 @@ func TestNilInboundMessage(t *testing.T) { parsedMsg, err := mb.parseInbound(msgBytes, ids.EmptyNodeID, func() {}) require.NoError(err) - pingMsg, ok := parsedMsg.message.(*p2p.Ping) - require.True(ok) + require.IsType(&p2p.Ping{}, parsedMsg.message) + pingMsg := parsedMsg.message.(*p2p.Ping) require.NotNil(pingMsg) } diff --git a/network/throttling/bandwidth_throttler_test.go b/network/throttling/bandwidth_throttler_test.go index b4955959eed..f2a5e094b72 100644 --- a/network/throttling/bandwidth_throttler_test.go +++ b/network/throttling/bandwidth_throttler_test.go @@ -25,8 +25,8 @@ func TestBandwidthThrottler(t *testing.T) { } throttlerIntf, err := newBandwidthThrottler(logging.NoLog{}, "", prometheus.NewRegistry(), config) require.NoError(err) - throttler, ok := throttlerIntf.(*bandwidthThrottlerImpl) - require.True(ok) + require.IsType(&bandwidthThrottlerImpl{}, throttlerIntf) + throttler := throttlerIntf.(*bandwidthThrottlerImpl) require.NotNil(throttler.log) require.NotNil(throttler.limiters) require.Equal(config.RefillRate, throttler.RefillRate) diff --git a/network/throttling/inbound_resource_throttler_test.go b/network/throttling/inbound_resource_throttler_test.go index 70bd4404e78..e96f5f15c8d 100644 --- a/network/throttling/inbound_resource_throttler_test.go +++ b/network/throttling/inbound_resource_throttler_test.go @@ -40,8 +40,8 @@ func TestNewSystemThrottler(t *testing.T) { targeter := tracker.NewMockTargeter(ctrl) throttlerIntf, err := NewSystemThrottler("", reg, config, cpuTracker, targeter) require.NoError(err) - throttler, ok := throttlerIntf.(*systemThrottler) - require.True(ok) + require.IsType(&systemThrottler{}, throttlerIntf) + throttler := throttlerIntf.(*systemThrottler) require.Equal(clock, config.Clock) require.Equal(time.Second, config.MaxRecheckDelay) require.Equal(cpuTracker, throttler.tracker) diff --git a/snow/consensus/snowball/unary_snowball_test.go b/snow/consensus/snowball/unary_snowball_test.go index 012144bb85b..54c5e47b443 100644 --- a/snow/consensus/snowball/unary_snowball_test.go +++ b/snow/consensus/snowball/unary_snowball_test.go @@ -5,6 +5,8 @@ package snowball import ( "testing" + + "github.com/stretchr/testify/require" ) func UnarySnowballStateTest(t *testing.T, sb *unarySnowball, expectedNumSuccessfulPolls, expectedConfidence int, expectedFinalized bool) { @@ -18,6 +20,8 @@ func UnarySnowballStateTest(t *testing.T, sb *unarySnowball, expectedNumSuccessf } func TestUnarySnowball(t *testing.T) { + require := require.New(t) + beta := 2 sb := &unarySnowball{} @@ -33,10 +37,8 @@ func TestUnarySnowball(t *testing.T) { UnarySnowballStateTest(t, sb, 2, 1, false) sbCloneIntf := sb.Clone() - sbClone, ok := sbCloneIntf.(*unarySnowball) - if !ok { - t.Fatalf("Unexpected clone type") - } + require.IsType(&unarySnowball{}, sbCloneIntf) + sbClone := sbCloneIntf.(*unarySnowball) UnarySnowballStateTest(t, sbClone, 2, 1, false) diff --git a/snow/consensus/snowball/unary_snowflake_test.go b/snow/consensus/snowball/unary_snowflake_test.go index ab76c94a410..850c3116b81 100644 --- a/snow/consensus/snowball/unary_snowflake_test.go +++ b/snow/consensus/snowball/unary_snowflake_test.go @@ -5,6 +5,8 @@ package snowball import ( "testing" + + "github.com/stretchr/testify/require" ) func UnarySnowflakeStateTest(t *testing.T, sf *unarySnowflake, expectedConfidence int, expectedFinalized bool) { @@ -16,6 +18,8 @@ func UnarySnowflakeStateTest(t *testing.T, sf *unarySnowflake, expectedConfidenc } func TestUnarySnowflake(t *testing.T) { + require := require.New(t) + beta := 2 sf := &unarySnowflake{} @@ -31,10 +35,8 @@ func TestUnarySnowflake(t *testing.T) { UnarySnowflakeStateTest(t, sf, 1, false) sfCloneIntf := sf.Clone() - sfClone, ok := sfCloneIntf.(*unarySnowflake) - if !ok { - t.Fatalf("Unexpected clone type") - } + require.IsType(&unarySnowflake{}, sfCloneIntf) + sfClone := sfCloneIntf.(*unarySnowflake) UnarySnowflakeStateTest(t, sfClone, 1, false) diff --git a/snow/engine/avalanche/getter/getter_test.go b/snow/engine/avalanche/getter/getter_test.go index 613bb7b05b4..b4028d7cd8e 100644 --- a/snow/engine/avalanche/getter/getter_test.go +++ b/snow/engine/avalanche/getter/getter_test.go @@ -8,6 +8,8 @@ import ( "errors" "testing" + "github.com/stretchr/testify/require" + "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow" "github.com/ava-labs/avalanchego/snow/choices" @@ -62,6 +64,8 @@ func testSetup(t *testing.T) (*vertex.TestManager, *common.SenderTest, common.Co } func TestAcceptedFrontier(t *testing.T) { + require := require.New(t) + manager, sender, config := testSetup(t) vtxID0 := ids.GenerateTestID() @@ -72,10 +76,8 @@ func TestAcceptedFrontier(t *testing.T) { if err != nil { t.Fatal(err) } - bs, ok := bsIntf.(*getter) - if !ok { - t.Fatal("Unexpected get handler") - } + require.IsType(&getter{}, bsIntf) + bs := bsIntf.(*getter) manager.EdgeF = func(context.Context) []ids.ID { return []ids.ID{ @@ -110,6 +112,8 @@ func TestAcceptedFrontier(t *testing.T) { } func TestFilterAccepted(t *testing.T) { + require := require.New(t) + manager, sender, config := testSetup(t) vtxID0 := ids.GenerateTestID() @@ -129,10 +133,8 @@ func TestFilterAccepted(t *testing.T) { if err != nil { t.Fatal(err) } - bs, ok := bsIntf.(*getter) - if !ok { - t.Fatal("Unexpected get handler") - } + require.IsType(&getter{}, bsIntf) + bs := bsIntf.(*getter) vtxIDs := []ids.ID{vtxID0, vtxID1, vtxID2} diff --git a/snow/engine/snowman/bootstrap/bootstrapper_test.go b/snow/engine/snowman/bootstrap/bootstrapper_test.go index 9f644c18eed..d3f98fb55b1 100644 --- a/snow/engine/snowman/bootstrap/bootstrapper_test.go +++ b/snow/engine/snowman/bootstrap/bootstrapper_test.go @@ -1104,6 +1104,8 @@ func TestBootstrapperFinalized(t *testing.T) { } func TestRestartBootstrapping(t *testing.T) { + require := require.New(t) + config, peerID, sender, vm := newConfig(t) blkID0 := ids.Empty.Prefix(0) @@ -1238,10 +1240,8 @@ func TestRestartBootstrapping(t *testing.T) { if err != nil { t.Fatal(err) } - bs, ok := bsIntf.(*bootstrapper) - if !ok { - t.Fatal("unexpected bootstrapper type") - } + require.IsType(&bootstrapper{}, bsIntf) + bs := bsIntf.(*bootstrapper) vm.CantSetState = false if err := bs.Start(context.Background(), 0); err != nil { @@ -1322,6 +1322,8 @@ func TestRestartBootstrapping(t *testing.T) { } func TestBootstrapOldBlockAfterStateSync(t *testing.T) { + require := require.New(t) + config, peerID, sender, vm := newConfig(t) blk0 := &snowman.TestBlock{ @@ -1380,10 +1382,8 @@ func TestBootstrapOldBlockAfterStateSync(t *testing.T) { if err != nil { t.Fatal(err) } - bs, ok := bsIntf.(*bootstrapper) - if !ok { - t.Fatal("unexpected bootstrapper type") - } + require.IsType(&bootstrapper{}, bsIntf) + bs := bsIntf.(*bootstrapper) vm.CantSetState = false if err := bs.Start(context.Background(), 0); err != nil { @@ -1423,6 +1423,8 @@ func TestBootstrapOldBlockAfterStateSync(t *testing.T) { } func TestBootstrapContinueAfterHalt(t *testing.T) { + require := require.New(t) + config, _, _, vm := newConfig(t) blk0 := &snowman.TestBlock{ @@ -1469,10 +1471,8 @@ func TestBootstrapContinueAfterHalt(t *testing.T) { if err != nil { t.Fatal(err) } - bs, ok := bsIntf.(*bootstrapper) - if !ok { - t.Fatal("unexpected bootstrapper type") - } + require.IsType(&bootstrapper{}, bsIntf) + bs := bsIntf.(*bootstrapper) vm.GetBlockF = func(_ context.Context, blkID ids.ID) (snowman.Block, error) { switch blkID { diff --git a/snow/engine/snowman/getter/getter_test.go b/snow/engine/snowman/getter/getter_test.go index 40e6fc03103..3dfdb9ad560 100644 --- a/snow/engine/snowman/getter/getter_test.go +++ b/snow/engine/snowman/getter/getter_test.go @@ -110,10 +110,8 @@ func TestAcceptedFrontier(t *testing.T) { if err != nil { t.Fatal(err) } - bs, ok := bsIntf.(*getter) - if !ok { - t.Fatal("Unexpected get handler") - } + require.IsType(t, &getter{}, bsIntf) + bs := bsIntf.(*getter) var accepted []ids.ID sender.SendAcceptedFrontierF = func(_ context.Context, _ ids.NodeID, _ uint32, frontier []ids.ID) { @@ -164,10 +162,8 @@ func TestFilterAccepted(t *testing.T) { if err != nil { t.Fatal(err) } - bs, ok := bsIntf.(*getter) - if !ok { - t.Fatal("Unexpected get handler") - } + require.IsType(t, &getter{}, bsIntf) + bs := bsIntf.(*getter) blkIDs := []ids.ID{blkID0, blkID1, blkID2} vm.GetBlockF = func(_ context.Context, blkID ids.ID) (snowman.Block, error) { diff --git a/snow/engine/snowman/syncer/utils_test.go b/snow/engine/snowman/syncer/utils_test.go index 013037432e0..0150a876d97 100644 --- a/snow/engine/snowman/syncer/utils_test.go +++ b/snow/engine/snowman/syncer/utils_test.go @@ -89,8 +89,8 @@ func buildTestsObjects(t *testing.T, commonCfg *common.Config) ( commonSyncer := New(cfg, func(context.Context, uint32) error { return nil }) - syncer, ok := commonSyncer.(*stateSyncer) - require.True(t, ok) + require.IsType(t, &stateSyncer{}, commonSyncer) + syncer := commonSyncer.(*stateSyncer) require.True(t, syncer.stateSyncVM != nil) fullVM.GetOngoingSyncStateSummaryF = func(context.Context) (block.StateSummary, error) { diff --git a/snow/networking/sender/sender_test.go b/snow/networking/sender/sender_test.go index 04bfeea6d56..2950bbc8697 100644 --- a/snow/networking/sender/sender_test.go +++ b/snow/networking/sender/sender_test.go @@ -664,8 +664,8 @@ func TestSender_Bootstrap_Requests(t *testing.T) { ) }, assertMsgToMyself: func(require *require.Assertions, msg message.InboundMessage) { - innerMsg, ok := msg.Message().(*p2p.GetStateSummaryFrontier) - require.True(ok) + require.IsType(&p2p.GetStateSummaryFrontier{}, msg.Message()) + innerMsg := msg.Message().(*p2p.GetStateSummaryFrontier) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) require.Equal(uint64(deadline), innerMsg.Deadline) @@ -709,8 +709,8 @@ func TestSender_Bootstrap_Requests(t *testing.T) { ) }, assertMsgToMyself: func(require *require.Assertions, msg message.InboundMessage) { - innerMsg, ok := msg.Message().(*p2p.GetAcceptedStateSummary) - require.True(ok) + require.IsType(&p2p.GetAcceptedStateSummary{}, msg.Message()) + innerMsg := msg.Message().(*p2p.GetAcceptedStateSummary) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) require.Equal(uint64(deadline), innerMsg.Deadline) @@ -753,8 +753,8 @@ func TestSender_Bootstrap_Requests(t *testing.T) { ) }, assertMsgToMyself: func(require *require.Assertions, msg message.InboundMessage) { - innerMsg, ok := msg.Message().(*p2p.GetAcceptedFrontier) - require.True(ok) + require.IsType(&p2p.GetAcceptedFrontier{}, msg.Message()) + innerMsg := msg.Message().(*p2p.GetAcceptedFrontier) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) require.Equal(uint64(deadline), innerMsg.Deadline) @@ -798,8 +798,8 @@ func TestSender_Bootstrap_Requests(t *testing.T) { ) }, assertMsgToMyself: func(require *require.Assertions, msg message.InboundMessage) { - innerMsg, ok := msg.Message().(*p2p.GetAccepted) - require.True(ok) + require.IsType(&p2p.GetAccepted{}, msg.Message()) + innerMsg := msg.Message().(*p2p.GetAccepted) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) require.Equal(uint64(deadline), innerMsg.Deadline) @@ -953,8 +953,8 @@ func TestSender_Bootstrap_Responses(t *testing.T) { ).Return(nil, nil) // Don't care about the message }, assertMsgToMyself: func(require *require.Assertions, msg message.InboundMessage) { - innerMsg, ok := msg.Message().(*p2p.StateSummaryFrontier) - require.True(ok) + require.IsType(&p2p.StateSummaryFrontier{}, msg.Message()) + innerMsg := msg.Message().(*p2p.StateSummaryFrontier) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) require.Equal(summary, innerMsg.Summary) @@ -981,8 +981,8 @@ func TestSender_Bootstrap_Responses(t *testing.T) { ).Return(nil, nil) // Don't care about the message }, assertMsgToMyself: func(require *require.Assertions, msg message.InboundMessage) { - innerMsg, ok := msg.Message().(*p2p.AcceptedStateSummary) - require.True(ok) + require.IsType(&p2p.AcceptedStateSummary{}, msg.Message()) + innerMsg := msg.Message().(*p2p.AcceptedStateSummary) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) for i, summaryID := range summaryIDs { @@ -1011,8 +1011,8 @@ func TestSender_Bootstrap_Responses(t *testing.T) { ).Return(nil, nil) // Don't care about the message }, assertMsgToMyself: func(require *require.Assertions, msg message.InboundMessage) { - innerMsg, ok := msg.Message().(*p2p.AcceptedFrontier) - require.True(ok) + require.IsType(&p2p.AcceptedFrontier{}, msg.Message()) + innerMsg := msg.Message().(*p2p.AcceptedFrontier) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) for i, summaryID := range summaryIDs { @@ -1041,8 +1041,8 @@ func TestSender_Bootstrap_Responses(t *testing.T) { ).Return(nil, nil) // Don't care about the message }, assertMsgToMyself: func(require *require.Assertions, msg message.InboundMessage) { - innerMsg, ok := msg.Message().(*p2p.Accepted) - require.True(ok) + require.IsType(&p2p.Accepted{}, msg.Message()) + innerMsg := msg.Message().(*p2p.Accepted) require.Equal(chainID[:], innerMsg.ChainId) require.Equal(requestID, innerMsg.RequestId) for i, summaryID := range summaryIDs { @@ -1166,8 +1166,8 @@ func TestSender_Single_Request(t *testing.T) { ) }, assertMsgToMyself: func(require *require.Assertions, msg message.InboundMessage) { - innerMsg, ok := msg.Message().(*message.GetAncestorsFailed) - require.True(ok) + require.IsType(&message.GetAncestorsFailed{}, msg.Message()) + innerMsg := msg.Message().(*message.GetAncestorsFailed) require.Equal(chainID, innerMsg.ChainID) require.Equal(requestID, innerMsg.RequestID) require.Equal(engineType, innerMsg.EngineType) @@ -1205,8 +1205,8 @@ func TestSender_Single_Request(t *testing.T) { ) }, assertMsgToMyself: func(require *require.Assertions, msg message.InboundMessage) { - innerMsg, ok := msg.Message().(*message.GetFailed) - require.True(ok) + require.IsType(&message.GetFailed{}, msg.Message()) + innerMsg := msg.Message().(*message.GetFailed) require.Equal(chainID, innerMsg.ChainID) require.Equal(requestID, innerMsg.RequestID) require.Equal(engineType, innerMsg.EngineType) diff --git a/snow/networking/tracker/resource_tracker_test.go b/snow/networking/tracker/resource_tracker_test.go index 11904e485ee..a60b9f25a66 100644 --- a/snow/networking/tracker/resource_tracker_test.go +++ b/snow/networking/tracker/resource_tracker_test.go @@ -27,8 +27,8 @@ func TestNewCPUTracker(t *testing.T) { trackerIntf, err := NewResourceTracker(reg, resource.NoUsage, factory, halflife) require.NoError(err) - tracker, ok := trackerIntf.(*resourceTracker) - require.True(ok) + require.IsType(&resourceTracker{}, trackerIntf) + tracker := trackerIntf.(*resourceTracker) require.Equal(factory, tracker.factory) require.NotNil(tracker.processingMeter) require.Equal(halflife, tracker.halflife) diff --git a/snow/networking/tracker/targeter_test.go b/snow/networking/tracker/targeter_test.go index 11d2cca4ae1..a70afbc8a99 100644 --- a/snow/networking/tracker/targeter_test.go +++ b/snow/networking/tracker/targeter_test.go @@ -33,8 +33,8 @@ func TestNewTargeter(t *testing.T) { vdrs, tracker, ) - targeter, ok := targeterIntf.(*targeter) - require.True(ok) + require.IsType(&targeter{}, targeterIntf) + targeter := targeterIntf.(*targeter) require.Equal(vdrs, targeter.vdrs) require.Equal(tracker, targeter.tracker) require.Equal(config.MaxNonVdrUsage, targeter.maxNonVdrUsage) diff --git a/utils/buffer/unbounded_deque_test.go b/utils/buffer/unbounded_deque_test.go index dcbfbe1c7a8..ea9ccd8782a 100644 --- a/utils/buffer/unbounded_deque_test.go +++ b/utils/buffer/unbounded_deque_test.go @@ -13,11 +13,11 @@ func TestUnboundedDeque_InitialCapGreaterThanMin(t *testing.T) { require := require.New(t) bIntf := NewUnboundedDeque[int](10) - b, ok := bIntf.(*unboundedSliceDeque[int]) - require.True(ok) + require.IsType(&unboundedSliceDeque[int]{}, bIntf) + b := bIntf.(*unboundedSliceDeque[int]) require.Empty(b.List()) require.Equal(0, b.Len()) - _, ok = b.Index(0) + _, ok := b.Index(0) require.False(ok) b.PushLeft(1) @@ -233,8 +233,8 @@ func TestUnboundedSliceDequePushLeftPopLeft(t *testing.T) { // Starts empty. bIntf := NewUnboundedDeque[int](2) - b, ok := bIntf.(*unboundedSliceDeque[int]) - require.True(ok) + require.IsType(&unboundedSliceDeque[int]{}, bIntf) + b := bIntf.(*unboundedSliceDeque[int]) require.Equal(0, bIntf.Len()) require.Equal(2, len(b.data)) require.Equal(0, b.left) @@ -242,7 +242,7 @@ func TestUnboundedSliceDequePushLeftPopLeft(t *testing.T) { require.Empty(b.List()) // slice is [EMPTY] - _, ok = b.PopLeft() + _, ok := b.PopLeft() require.False(ok) _, ok = b.PeekLeft() require.False(ok) @@ -416,8 +416,8 @@ func TestUnboundedSliceDequePushRightPopRight(t *testing.T) { // Starts empty. bIntf := NewUnboundedDeque[int](2) - b, ok := bIntf.(*unboundedSliceDeque[int]) - require.True(ok) + require.IsType(&unboundedSliceDeque[int]{}, bIntf) + b := bIntf.(*unboundedSliceDeque[int]) require.Equal(0, bIntf.Len()) require.Equal(2, len(b.data)) require.Equal(0, b.left) @@ -425,7 +425,7 @@ func TestUnboundedSliceDequePushRightPopRight(t *testing.T) { require.Empty(b.List()) // slice is [EMPTY] - _, ok = b.PopRight() + _, ok := b.PopRight() require.False(ok) _, ok = b.PeekLeft() require.False(ok) diff --git a/utils/dynamicip/updater_test.go b/utils/dynamicip/updater_test.go index c31031f988c..98ce26b4a18 100644 --- a/utils/dynamicip/updater_test.go +++ b/utils/dynamicip/updater_test.go @@ -44,8 +44,8 @@ func TestNewUpdater(t *testing.T) { ) // Assert NewUpdater returns expected type - updater, ok := updaterIntf.(*updater) - require.True(ok) + require.IsType(&updater{}, updaterIntf) + updater := updaterIntf.(*updater) // Assert fields set require.Equal(dynamicIP, updater.dynamicIP) diff --git a/utils/window/window_test.go b/utils/window/window_test.go index 8ca715674f8..b98bcc04f8c 100644 --- a/utils/window/window_test.go +++ b/utils/window/window_test.go @@ -151,8 +151,8 @@ func TestTTLOldest(t *testing.T) { TTL: testTTL, }, ) - window, ok := windowIntf.(*window[int]) - require.True(t, ok) + require.IsType(t, &window[int]{}, windowIntf) + window := windowIntf.(*window[int]) epochStart := time.Unix(0, 0) clock.Set(epochStart) diff --git a/vms/avm/blocks/block_test.go b/vms/avm/blocks/block_test.go index 0d0d7768846..4a20080391c 100644 --- a/vms/avm/blocks/block_test.go +++ b/vms/avm/blocks/block_test.go @@ -55,8 +55,8 @@ func TestStandardBlocks(t *testing.T) { require.Equal(standardBlk.Bytes(), parsed.Bytes()) require.Equal(standardBlk.Timestamp(), parsed.Timestamp()) - parsedStandardBlk, ok := parsed.(*StandardBlock) - require.True(ok) + require.IsType(&StandardBlock{}, parsed) + parsedStandardBlk := parsed.(*StandardBlock) require.Equal(txs, parsedStandardBlk.Txs()) require.Equal(parsed.Txs(), parsedStandardBlk.Txs()) diff --git a/vms/avm/blocks/builder/builder_test.go b/vms/avm/blocks/builder/builder_test.go index dbeaed63db0..a1501a0cddb 100644 --- a/vms/avm/blocks/builder/builder_test.go +++ b/vms/avm/blocks/builder/builder_test.go @@ -265,8 +265,8 @@ func TestBuilderBuildBlock(t *testing.T) { unsignedTx1.EXPECT().Visit(gomock.Any()).Return(nil) // Pass semantic verification unsignedTx1.EXPECT().Visit(gomock.Any()).DoAndReturn( // Pass execution func(visitor txs.Visitor) error { - executor, ok := visitor.(*txexecutor.Executor) - require.True(t, ok) + require.IsType(t, &txexecutor.Executor{}, visitor) + executor := visitor.(*txexecutor.Executor) executor.Inputs.Add(inputID) return nil }, @@ -282,8 +282,8 @@ func TestBuilderBuildBlock(t *testing.T) { unsignedTx2.EXPECT().Visit(gomock.Any()).Return(nil) // Pass semantic verification unsignedTx2.EXPECT().Visit(gomock.Any()).DoAndReturn( // Pass execution func(visitor txs.Visitor) error { - executor, ok := visitor.(*txexecutor.Executor) - require.True(t, ok) + require.IsType(t, &txexecutor.Executor{}, visitor) + executor := visitor.(*txexecutor.Executor) executor.Inputs.Add(inputID) return nil }, @@ -374,8 +374,8 @@ func TestBuilderBuildBlock(t *testing.T) { unsignedTx.EXPECT().Visit(gomock.Any()).Return(nil) // Pass semantic verification unsignedTx.EXPECT().Visit(gomock.Any()).DoAndReturn( // Pass execution func(visitor txs.Visitor) error { - executor, ok := visitor.(*txexecutor.Executor) - require.True(t, ok) + require.IsType(t, &txexecutor.Executor{}, visitor) + executor := visitor.(*txexecutor.Executor) executor.Inputs.Add(inputID) return nil }, @@ -448,8 +448,8 @@ func TestBuilderBuildBlock(t *testing.T) { unsignedTx.EXPECT().Visit(gomock.Any()).Return(nil) // Pass semantic verification unsignedTx.EXPECT().Visit(gomock.Any()).DoAndReturn( // Pass execution func(visitor txs.Visitor) error { - executor, ok := visitor.(*txexecutor.Executor) - require.True(t, ok) + require.IsType(t, &txexecutor.Executor{}, visitor) + executor := visitor.(*txexecutor.Executor) executor.Inputs.Add(inputID) return nil }, diff --git a/vms/avm/network/network_test.go b/vms/avm/network/network_test.go index ec32a6d3780..0e0cb76aa34 100644 --- a/vms/avm/network/network_test.go +++ b/vms/avm/network/network_test.go @@ -326,8 +326,8 @@ func TestNetworkGossipTx(t *testing.T) { mempool.NewMockMempool(ctrl), appSender, ) - n, ok := nIntf.(*network) - require.True(ok) + require.IsType(&network{}, nIntf) + n := nIntf.(*network) // Case: Tx was recently gossiped txID := ids.GenerateTestID() diff --git a/vms/components/chain/state_test.go b/vms/components/chain/state_test.go index 100c9d0a2ab..30a14fa67d6 100644 --- a/vms/components/chain/state_test.go +++ b/vms/components/chain/state_test.go @@ -135,9 +135,9 @@ func cantBuildBlock(context.Context) (snowman.Block, error) { // checkProcessingBlock checks that [blk] is of the correct type and is // correctly uniquified when calling GetBlock and ParseBlock. func checkProcessingBlock(t *testing.T, s *State, blk snowman.Block) { - if _, ok := blk.(*BlockWrapper); !ok { - t.Fatalf("Expected block to be of type (*BlockWrapper)") - } + require := require.New(t) + + require.IsType(&BlockWrapper{}, blk) parsedBlk, err := s.ParseBlock(context.Background(), blk.Bytes()) if err != nil { @@ -168,9 +168,9 @@ func checkProcessingBlock(t *testing.T, s *State, blk snowman.Block) { // checkDecidedBlock asserts that [blk] is returned with the correct status by ParseBlock // and GetBlock. func checkDecidedBlock(t *testing.T, s *State, blk snowman.Block, expectedStatus choices.Status, cached bool) { - if _, ok := blk.(*BlockWrapper); !ok { - t.Fatalf("Expected block to be of type (*BlockWrapper)") - } + require := require.New(t) + + require.IsType(&BlockWrapper{}, blk) parsedBlk, err := s.ParseBlock(context.Background(), blk.Bytes()) if err != nil { @@ -521,6 +521,7 @@ func TestStateParent(t *testing.T) { } func TestGetBlockInternal(t *testing.T) { + require := require.New(t) testBlks := NewTestBlocks(1) genesisBlock := testBlks[0] genesisBlock.SetStatus(choices.Accepted) @@ -539,9 +540,7 @@ func TestGetBlockInternal(t *testing.T) { }) genesisBlockInternal := chainState.LastAcceptedBlockInternal() - if _, ok := genesisBlockInternal.(*TestBlock); !ok { - t.Fatalf("Expected LastAcceptedBlockInternal to return a block of type *snowman.TestBlock, but found %T", genesisBlockInternal) - } + require.IsType(&TestBlock{}, genesisBlockInternal) if genesisBlockInternal.ID() != genesisBlock.ID() { t.Fatalf("Expected LastAcceptedBlockInternal to be blk %s, but found %s", genesisBlock.ID(), genesisBlockInternal.ID()) } @@ -551,9 +550,7 @@ func TestGetBlockInternal(t *testing.T) { t.Fatal(err) } - if _, ok := blk.(*TestBlock); !ok { - t.Fatalf("Expected retrieved block to return a block of type *snowman.TestBlock, but found %T", blk) - } + require.IsType(&TestBlock{}, blk) if blk.ID() != genesisBlock.ID() { t.Fatalf("Expected GetBlock to be blk %s, but found %s", genesisBlock.ID(), blk.ID()) } diff --git a/vms/components/message/message_test.go b/vms/components/message/message_test.go index 89a8f3f7547..38b5099e2b4 100644 --- a/vms/components/message/message_test.go +++ b/vms/components/message/message_test.go @@ -38,8 +38,8 @@ func TestParseProto(t *testing.T) { parsedMsgIntf, err := Parse(msgBytes) require.NoError(err) - parsedMsg, ok := parsedMsgIntf.(*Tx) - require.True(ok) + require.IsType(&Tx{}, parsedMsgIntf) + parsedMsg := parsedMsgIntf.(*Tx) require.Equal(txBytes, parsedMsg.Tx) diff --git a/vms/components/message/tx_test.go b/vms/components/message/tx_test.go index 58a06e1bd77..3634abb4c71 100644 --- a/vms/components/message/tx_test.go +++ b/vms/components/message/tx_test.go @@ -27,8 +27,8 @@ func TestTx(t *testing.T) { require.NoError(err) require.Equal(builtMsgBytes, parsedMsgIntf.Bytes()) - parsedMsg, ok := parsedMsgIntf.(*Tx) - require.True(ok) + require.IsType(&Tx{}, parsedMsgIntf) + parsedMsg := parsedMsgIntf.(*Tx) require.Equal(tx, parsedMsg.Tx) } diff --git a/vms/platformvm/blocks/builder/builder_test.go b/vms/platformvm/blocks/builder/builder_test.go index f10791e2972..3f3efe80aff 100644 --- a/vms/platformvm/blocks/builder/builder_test.go +++ b/vms/platformvm/blocks/builder/builder_test.go @@ -61,8 +61,8 @@ func TestBlockBuilderAddLocalTx(t *testing.T) { blkIntf, err := env.Builder.BuildBlock(context.Background()) require.NoError(err) - blk, ok := blkIntf.(*blockexecutor.Block) - require.True(ok) + require.IsType(&blockexecutor.Block{}, blkIntf) + blk := blkIntf.(*blockexecutor.Block) require.Len(blk.Txs(), 1) require.Equal(txID, blk.Txs()[0].ID()) diff --git a/vms/platformvm/blocks/executor/manager_test.go b/vms/platformvm/blocks/executor/manager_test.go index 4bd341e2365..5e488a98a64 100644 --- a/vms/platformvm/blocks/executor/manager_test.go +++ b/vms/platformvm/blocks/executor/manager_test.go @@ -44,8 +44,8 @@ func TestGetBlock(t *testing.T) { gotBlk, err := manager.GetBlock(statelessBlk.ID()) require.NoError(err) require.Equal(statelessBlk.ID(), gotBlk.ID()) - innerBlk, ok := gotBlk.(*Block) - require.True(ok) + require.IsType(&Block{}, gotBlk) + innerBlk := gotBlk.(*Block) require.Equal(statelessBlk, innerBlk.Block) require.Equal(manager, innerBlk.manager) } @@ -57,8 +57,8 @@ func TestGetBlock(t *testing.T) { gotBlk, err := manager.GetBlock(statelessBlk.ID()) require.NoError(err) require.Equal(statelessBlk.ID(), gotBlk.ID()) - innerBlk, ok := gotBlk.(*Block) - require.True(ok) + require.IsType(&Block{}, gotBlk) + innerBlk := gotBlk.(*Block) require.Equal(statelessBlk, innerBlk.Block) require.Equal(manager, innerBlk.manager) } diff --git a/vms/platformvm/blocks/parse_test.go b/vms/platformvm/blocks/parse_test.go index 9ab17af75bc..a254a5e76e0 100644 --- a/vms/platformvm/blocks/parse_test.go +++ b/vms/platformvm/blocks/parse_test.go @@ -43,8 +43,7 @@ func TestStandardBlocks(t *testing.T) { require.Equal(apricotStandardBlk.Parent(), parsed.Parent()) require.Equal(apricotStandardBlk.Height(), parsed.Height()) - _, ok := parsed.(*ApricotStandardBlock) - require.True(ok) + require.IsType(&ApricotStandardBlock{}, parsed) require.Equal(txs, parsed.Txs()) // check that banff standard block can be built and parsed @@ -60,8 +59,8 @@ func TestStandardBlocks(t *testing.T) { require.Equal(banffStandardBlk.Bytes(), parsed.Bytes()) require.Equal(banffStandardBlk.Parent(), parsed.Parent()) require.Equal(banffStandardBlk.Height(), parsed.Height()) - parsedBanffStandardBlk, ok := parsed.(*BanffStandardBlock) - require.True(ok) + require.IsType(&BanffStandardBlock{}, parsed) + parsedBanffStandardBlk := parsed.(*BanffStandardBlock) require.Equal(txs, parsedBanffStandardBlk.Txs()) // timestamp check for banff blocks only @@ -100,8 +99,8 @@ func TestProposalBlocks(t *testing.T) { require.Equal(apricotProposalBlk.Parent(), parsed.Parent()) require.Equal(apricotProposalBlk.Height(), parsed.Height()) - parsedApricotProposalBlk, ok := parsed.(*ApricotProposalBlock) - require.True(ok) + require.IsType(&ApricotProposalBlock{}, parsed) + parsedApricotProposalBlk := parsed.(*ApricotProposalBlock) require.Equal([]*txs.Tx{tx}, parsedApricotProposalBlk.Txs()) // check that banff proposal block can be built and parsed @@ -122,8 +121,8 @@ func TestProposalBlocks(t *testing.T) { require.Equal(banffProposalBlk.Bytes(), parsed.Bytes()) require.Equal(banffProposalBlk.Parent(), banffProposalBlk.Parent()) require.Equal(banffProposalBlk.Height(), parsed.Height()) - parsedBanffProposalBlk, ok := parsed.(*BanffProposalBlock) - require.True(ok) + require.IsType(&BanffProposalBlock{}, parsed) + parsedBanffProposalBlk := parsed.(*BanffProposalBlock) require.Equal([]*txs.Tx{tx}, parsedBanffProposalBlk.Txs()) // timestamp check for banff blocks only @@ -171,8 +170,8 @@ func TestCommitBlock(t *testing.T) { require.Equal(banffCommitBlk.Height(), parsed.Height()) // timestamp check for banff blocks only - parsedBanffCommitBlk, ok := parsed.(*BanffCommitBlock) - require.True(ok) + require.IsType(&BanffCommitBlock{}, parsed) + parsedBanffCommitBlk := parsed.(*BanffCommitBlock) require.Equal(banffCommitBlk.Timestamp(), parsedBanffCommitBlk.Timestamp()) } } @@ -214,8 +213,8 @@ func TestAbortBlock(t *testing.T) { require.Equal(banffAbortBlk.Height(), parsed.Height()) // timestamp check for banff blocks only - parsedBanffAbortBlk, ok := parsed.(*BanffAbortBlock) - require.True(ok) + require.IsType(&BanffAbortBlock{}, parsed) + parsedBanffAbortBlk := parsed.(*BanffAbortBlock) require.Equal(banffAbortBlk.Timestamp(), parsedBanffAbortBlk.Timestamp()) } } @@ -247,8 +246,8 @@ func TestAtomicBlock(t *testing.T) { require.Equal(atomicBlk.Parent(), parsed.Parent()) require.Equal(atomicBlk.Height(), parsed.Height()) - parsedAtomicBlk, ok := parsed.(*ApricotAtomicBlock) - require.True(ok) + require.IsType(&ApricotAtomicBlock{}, parsed) + parsedAtomicBlk := parsed.(*ApricotAtomicBlock) require.Equal([]*txs.Tx{tx}, parsedAtomicBlk.Txs()) } } diff --git a/vms/platformvm/service_test.go b/vms/platformvm/service_test.go index 1cc0abf5e86..8195a7ac3c9 100644 --- a/vms/platformvm/service_test.go +++ b/vms/platformvm/service_test.go @@ -352,8 +352,7 @@ func TestGetTx(t *testing.T) { require.NoError(err) commit := options[0].(*blockexecutor.Block) - _, ok := commit.Block.(*blocks.BanffCommitBlock) - require.True(ok) + require.IsType(&blocks.BanffCommitBlock{}, commit.Block) err := commit.Verify(context.Background()) require.NoError(err) diff --git a/vms/platformvm/txs/executor/reward_validator_test.go b/vms/platformvm/txs/executor/reward_validator_test.go index cd34cbdac06..c7e5058952e 100644 --- a/vms/platformvm/txs/executor/reward_validator_test.go +++ b/vms/platformvm/txs/executor/reward_validator_test.go @@ -457,8 +457,8 @@ func TestRewardDelegatorTxExecuteOnCommitPostDelegateeDeferral(t *testing.T) { utxo, err := onCommitState.GetUTXO(delRewardUTXOID.InputID()) require.NoError(err) - castUTXO, ok := utxo.Out.(*secp256k1fx.TransferOutput) - require.True(ok) + require.IsType(&secp256k1fx.TransferOutput{}, utxo.Out) + castUTXO := utxo.Out.(*secp256k1fx.TransferOutput) require.Equal(delRewardAmt*3/4, castUTXO.Amt, "expected delegator balance to increase by 3/4 of reward amount") require.True(delDestSet.Equals(castUTXO.AddressesSet()), "expected reward UTXO to be issued to delDestSet") @@ -505,8 +505,8 @@ func TestRewardDelegatorTxExecuteOnCommitPostDelegateeDeferral(t *testing.T) { utxo, err = onCommitState.GetUTXO(vdrRewardUTXOID.InputID()) require.NoError(err) - castUTXO, ok = utxo.Out.(*secp256k1fx.TransferOutput) - require.True(ok) + require.IsType(&secp256k1fx.TransferOutput{}, utxo.Out) + castUTXO = utxo.Out.(*secp256k1fx.TransferOutput) require.Equal(vdrRewardAmt, castUTXO.Amt, "expected validator to be rewarded") require.True(vdrDestSet.Equals(castUTXO.AddressesSet()), "expected reward UTXO to be issued to vdrDestSet") @@ -518,8 +518,8 @@ func TestRewardDelegatorTxExecuteOnCommitPostDelegateeDeferral(t *testing.T) { utxo, err = onCommitState.GetUTXO(onCommitVdrDelRewardUTXOID.InputID()) require.NoError(err) - castUTXO, ok = utxo.Out.(*secp256k1fx.TransferOutput) - require.True(ok) + require.IsType(&secp256k1fx.TransferOutput{}, utxo.Out) + castUTXO = utxo.Out.(*secp256k1fx.TransferOutput) require.Equal(delRewardAmt/4, castUTXO.Amt, "expected validator to be rewarded with accrued delegator rewards") require.True(vdrDestSet.Equals(castUTXO.AddressesSet()), "expected reward UTXO to be issued to vdrDestSet") @@ -531,8 +531,8 @@ func TestRewardDelegatorTxExecuteOnCommitPostDelegateeDeferral(t *testing.T) { utxo, err = onAbortState.GetUTXO(onAbortVdrDelRewardUTXOID.InputID()) require.NoError(err) - castUTXO, ok = utxo.Out.(*secp256k1fx.TransferOutput) - require.True(ok) + require.IsType(&secp256k1fx.TransferOutput{}, utxo.Out) + castUTXO = utxo.Out.(*secp256k1fx.TransferOutput) require.Equal(delRewardAmt/4, castUTXO.Amt, "expected validator to be rewarded with accrued delegator rewards") require.True(vdrDestSet.Equals(castUTXO.AddressesSet()), "expected reward UTXO to be issued to vdrDestSet") @@ -688,8 +688,8 @@ func TestRewardDelegatorTxAndValidatorTxExecuteOnCommitPostDelegateeDeferral(t * utxo, err := vdrOnAbortState.GetUTXO(onAbortVdrDelRewardUTXOID.InputID()) require.NoError(err) - castUTXO, ok := utxo.Out.(*secp256k1fx.TransferOutput) - require.True(ok) + require.IsType(&secp256k1fx.TransferOutput{}, utxo.Out) + castUTXO := utxo.Out.(*secp256k1fx.TransferOutput) require.Equal(delRewardAmt/4, castUTXO.Amt, "expected validator to be rewarded with accrued delegator rewards") require.True(vdrDestSet.Equals(castUTXO.AddressesSet()), "expected reward UTXO to be issued to vdrDestSet") diff --git a/vms/platformvm/vm_test.go b/vms/platformvm/vm_test.go index b3bc441e56e..fcf7b45a8c8 100644 --- a/vms/platformvm/vm_test.go +++ b/vms/platformvm/vm_test.go @@ -856,11 +856,9 @@ func TestRewardValidatorAccept(t *testing.T) { require.NoError(err) commit := options[0].(*blockexecutor.Block) - _, ok := commit.Block.(*blocks.BanffCommitBlock) - require.True(ok) + require.IsType(&blocks.BanffCommitBlock{}, commit.Block) abort := options[1].(*blockexecutor.Block) - _, ok = abort.Block.(*blocks.BanffAbortBlock) - require.True(ok) + require.IsType(&blocks.BanffAbortBlock{}, abort.Block) require.NoError(block.Accept(context.Background())) require.NoError(commit.Verify(context.Background())) @@ -900,12 +898,10 @@ func TestRewardValidatorAccept(t *testing.T) { require.NoError(err) commit = options[0].(*blockexecutor.Block) - _, ok = commit.Block.(*blocks.BanffCommitBlock) - require.True(ok) + require.IsType(&blocks.BanffCommitBlock{}, commit.Block) abort = options[1].(*blockexecutor.Block) - _, ok = abort.Block.(*blocks.BanffAbortBlock) - require.True(ok) + require.IsType(&blocks.BanffAbortBlock{}, abort.Block) require.NoError(block.Accept(context.Background())) require.NoError(commit.Verify(context.Background())) @@ -954,12 +950,10 @@ func TestRewardValidatorReject(t *testing.T) { require.NoError(err) commit := options[0].(*blockexecutor.Block) - _, ok := commit.Block.(*blocks.BanffCommitBlock) - require.True(ok) + require.IsType(&blocks.BanffCommitBlock{}, commit.Block) abort := options[1].(*blockexecutor.Block) - _, ok = abort.Block.(*blocks.BanffAbortBlock) - require.True(ok) + require.IsType(&blocks.BanffAbortBlock{}, abort.Block) require.NoError(block.Accept(context.Background())) require.NoError(commit.Verify(context.Background())) @@ -995,12 +989,10 @@ func TestRewardValidatorReject(t *testing.T) { require.NoError(err) commit = options[0].(*blockexecutor.Block) - _, ok = commit.Block.(*blocks.BanffCommitBlock) - require.True(ok) + require.IsType(&blocks.BanffCommitBlock{}, commit.Block) abort = options[1].(*blockexecutor.Block) - _, ok = abort.Block.(*blocks.BanffAbortBlock) - require.True(ok) + require.IsType(&blocks.BanffAbortBlock{}, abort.Block) require.NoError(blk.Accept(context.Background())) require.NoError(commit.Verify(context.Background())) @@ -1049,12 +1041,10 @@ func TestRewardValidatorPreferred(t *testing.T) { require.NoError(err) commit := options[0].(*blockexecutor.Block) - _, ok := commit.Block.(*blocks.BanffCommitBlock) - require.True(ok) + require.IsType(&blocks.BanffCommitBlock{}, commit.Block) abort := options[1].(*blockexecutor.Block) - _, ok = abort.Block.(*blocks.BanffAbortBlock) - require.True(ok) + require.IsType(&blocks.BanffAbortBlock{}, abort.Block) require.NoError(block.Accept(context.Background())) require.NoError(commit.Verify(context.Background())) @@ -1091,12 +1081,10 @@ func TestRewardValidatorPreferred(t *testing.T) { require.NoError(err) commit = options[0].(*blockexecutor.Block) - _, ok = commit.Block.(*blocks.BanffCommitBlock) - require.True(ok) + require.IsType(&blocks.BanffCommitBlock{}, commit.Block) abort = options[1].(*blockexecutor.Block) - _, ok = abort.Block.(*blocks.BanffAbortBlock) - require.True(ok) + require.IsType(&blocks.BanffAbortBlock{}, abort.Block) require.NoError(blk.Accept(context.Background())) require.NoError(commit.Verify(context.Background())) @@ -2195,12 +2183,10 @@ func TestUptimeDisallowedWithRestart(t *testing.T) { require.NoError(err) commit := options[0].(*blockexecutor.Block) - _, ok := commit.Block.(*blocks.BanffCommitBlock) - require.True(ok) + require.IsType(&blocks.BanffCommitBlock{}, commit.Block) abort := options[1].(*blockexecutor.Block) - _, ok = abort.Block.(*blocks.BanffAbortBlock) - require.True(ok) + require.IsType(&blocks.BanffAbortBlock{}, abort.Block) require.NoError(block.Accept(context.Background())) require.NoError(commit.Verify(context.Background())) @@ -2238,12 +2224,10 @@ func TestUptimeDisallowedWithRestart(t *testing.T) { require.NoError(err) commit = options[0].(*blockexecutor.Block) - _, ok = commit.Block.(*blocks.BanffCommitBlock) - require.True(ok) + require.IsType(&blocks.BanffCommitBlock{}, commit.Block) abort = options[1].(*blockexecutor.Block) - _, ok = abort.Block.(*blocks.BanffAbortBlock) - require.True(ok) + require.IsType(&blocks.BanffAbortBlock{}, abort.Block) require.NoError(blk.Accept(context.Background())) require.NoError(commit.Verify(context.Background())) @@ -2336,12 +2320,10 @@ func TestUptimeDisallowedAfterNeverConnecting(t *testing.T) { require.NoError(err) commit := options[0].(*blockexecutor.Block) - _, ok := commit.Block.(*blocks.BanffCommitBlock) - require.True(ok) + require.IsType(&blocks.BanffCommitBlock{}, commit.Block) abort := options[1].(*blockexecutor.Block) - _, ok = abort.Block.(*blocks.BanffAbortBlock) - require.True(ok) + require.IsType(&blocks.BanffAbortBlock{}, abort.Block) require.NoError(block.Accept(context.Background())) require.NoError(commit.Verify(context.Background())) @@ -2364,12 +2346,10 @@ func TestUptimeDisallowedAfterNeverConnecting(t *testing.T) { require.NoError(err) commit = options[0].(*blockexecutor.Block) - _, ok = commit.Block.(*blocks.BanffCommitBlock) - require.True(ok) + require.IsType(&blocks.BanffCommitBlock{}, commit.Block) abort = options[1].(*blockexecutor.Block) - _, ok = abort.Block.(*blocks.BanffAbortBlock) - require.True(ok) + require.IsType(&blocks.BanffAbortBlock{}, abort.Block) require.NoError(blk.Accept(context.Background())) require.NoError(commit.Verify(context.Background())) diff --git a/vms/proposervm/batched_vm_test.go b/vms/proposervm/batched_vm_test.go index 153cd5b4ee1..70e462cffed 100644 --- a/vms/proposervm/batched_vm_test.go +++ b/vms/proposervm/batched_vm_test.go @@ -367,8 +367,7 @@ func TestGetAncestorsAtSnomanPlusPlusFork(t *testing.T) { } builtBlk1, err := proRemoteVM.BuildBlock(context.Background()) require.NoError(err, "Could not build preFork block") - _, ok := builtBlk1.(*preForkBlock) - require.True(ok, "Block should be a pre-fork one") + require.IsType(&preForkBlock{}, builtBlk1) // prepare build of next block require.NoError(proRemoteVM.SetPreference(context.Background(), builtBlk1.ID())) @@ -396,8 +395,7 @@ func TestGetAncestorsAtSnomanPlusPlusFork(t *testing.T) { } builtBlk2, err := proRemoteVM.BuildBlock(context.Background()) require.NoError(err, "Could not build proposer block") - _, ok = builtBlk2.(*preForkBlock) - require.True(ok, "Block should be a pre-fork one") + require.IsType(&preForkBlock{}, builtBlk2) // prepare build of next block require.NoError(proRemoteVM.SetPreference(context.Background(), builtBlk2.ID())) @@ -427,8 +425,7 @@ func TestGetAncestorsAtSnomanPlusPlusFork(t *testing.T) { } builtBlk3, err := proRemoteVM.BuildBlock(context.Background()) require.NoError(err, "Could not build proposer block") - _, ok = builtBlk3.(*postForkBlock) - require.True(ok, "Block should be a post-fork one") + require.IsType(&postForkBlock{}, builtBlk3) // prepare build of next block require.NoError(builtBlk3.Verify(context.Background())) @@ -450,8 +447,7 @@ func TestGetAncestorsAtSnomanPlusPlusFork(t *testing.T) { } builtBlk4, err := proRemoteVM.BuildBlock(context.Background()) require.NoError(err, "Could not build proposer block") - _, ok = builtBlk4.(*postForkBlock) - require.True(ok, "Block should be a post-fork one") + require.IsType(&postForkBlock{}, builtBlk4) require.NoError(builtBlk4.Verify(context.Background())) // ...Call GetAncestors on them ... @@ -797,8 +793,7 @@ func TestBatchedParseBlockAtSnomanPlusPlusFork(t *testing.T) { } builtBlk1, err := proRemoteVM.BuildBlock(context.Background()) require.NoError(err, "Could not build preFork block") - _, ok := builtBlk1.(*preForkBlock) - require.True(ok, "Block should be a pre-fork one") + require.IsType(&preForkBlock{}, builtBlk1) // prepare build of next block require.NoError(proRemoteVM.SetPreference(context.Background(), builtBlk1.ID())) @@ -826,8 +821,7 @@ func TestBatchedParseBlockAtSnomanPlusPlusFork(t *testing.T) { } builtBlk2, err := proRemoteVM.BuildBlock(context.Background()) require.NoError(err, "Could not build proposer block") - _, ok = builtBlk2.(*preForkBlock) - require.True(ok, "Block should be a pre-fork one") + require.IsType(&preForkBlock{}, builtBlk2) // prepare build of next block require.NoError(proRemoteVM.SetPreference(context.Background(), builtBlk2.ID())) @@ -857,8 +851,7 @@ func TestBatchedParseBlockAtSnomanPlusPlusFork(t *testing.T) { } builtBlk3, err := proRemoteVM.BuildBlock(context.Background()) require.NoError(err, "Could not build proposer block") - _, ok = builtBlk3.(*postForkBlock) - require.True(ok, "Block should be a post-fork one") + require.IsType(&postForkBlock{}, builtBlk3) // prepare build of next block require.NoError(builtBlk3.Verify(context.Background())) @@ -880,8 +873,7 @@ func TestBatchedParseBlockAtSnomanPlusPlusFork(t *testing.T) { } builtBlk4, err := proRemoteVM.BuildBlock(context.Background()) require.NoError(err, "Could not build proposer block") - _, ok = builtBlk4.(*postForkBlock) - require.True(ok, "Block should be a post-fork one") + require.IsType(&postForkBlock{}, builtBlk4) require.NoError(builtBlk4.Verify(context.Background())) coreVM.ParseBlockF = func(_ context.Context, b []byte) (snowman.Block, error) { diff --git a/vms/proposervm/post_fork_block_test.go b/vms/proposervm/post_fork_block_test.go index f4912172700..441fb79979c 100644 --- a/vms/proposervm/post_fork_block_test.go +++ b/vms/proposervm/post_fork_block_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/choices" @@ -554,6 +556,8 @@ func TestBlockVerify_PostForkBlock_PChainHeightChecks(t *testing.T) { } func TestBlockVerify_PostForkBlockBuiltOnOption_PChainHeightChecks(t *testing.T) { + require := require.New(t) + coreVM, valState, proVM, coreGenBlk, _ := initTestProposerVM(t, time.Time{}, 0) // enable ProBlks pChainHeight := uint64(100) valState.GetCurrentHeightF = func(context.Context) (uint64, error) { @@ -639,10 +643,8 @@ func TestBlockVerify_PostForkBlockBuiltOnOption_PChainHeightChecks(t *testing.T) } // retrieve one option and verify block built on it - postForkOracleBlk, ok := oracleBlk.(*postForkBlock) - if !ok { - t.Fatal("expected post fork block") - } + require.IsType(&postForkBlock{}, oracleBlk) + postForkOracleBlk := oracleBlk.(*postForkBlock) opts, err := postForkOracleBlk.Options(context.Background()) if err != nil { t.Fatal("could not retrieve options from post fork oracle block") @@ -943,6 +945,8 @@ func TestBlockAccept_PostForkBlock_TwoProBlocksWithSameCoreBlock_OneIsAccepted(t // ProposerBlock.Reject tests section func TestBlockReject_PostForkBlock_InnerBlockIsNotRejected(t *testing.T) { + require := require.New(t) + coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, time.Time{}, 0) // enable ProBlks coreBlk := &snowman.TestBlock{ TestDecidable: choices.TestDecidable{ @@ -962,10 +966,8 @@ func TestBlockReject_PostForkBlock_InnerBlockIsNotRejected(t *testing.T) { if err != nil { t.Fatal("could not build block") } - proBlk, ok := sb.(*postForkBlock) - if !ok { - t.Fatal("built block has not expected type") - } + require.IsType(&postForkBlock{}, sb) + proBlk := sb.(*postForkBlock) if err := proBlk.Reject(context.Background()); err != nil { t.Fatal("could not reject block") @@ -981,6 +983,8 @@ func TestBlockReject_PostForkBlock_InnerBlockIsNotRejected(t *testing.T) { } func TestBlockVerify_PostForkBlock_ShouldBePostForkOption(t *testing.T) { + require := require.New(t) + coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, time.Time{}, 0) proVM.Set(coreGenBlk.Timestamp()) @@ -1064,17 +1068,13 @@ func TestBlockVerify_PostForkBlock_ShouldBePostForkOption(t *testing.T) { } // retrieve options ... - postForkOracleBlk, ok := parentBlk.(*postForkBlock) - if !ok { - t.Fatal("expected post fork block") - } + require.IsType(&postForkBlock{}, parentBlk) + postForkOracleBlk := parentBlk.(*postForkBlock) opts, err := postForkOracleBlk.Options(context.Background()) if err != nil { t.Fatal("could not retrieve options from post fork oracle block") } - if _, ok := opts[0].(*postForkOption); !ok { - t.Fatal("unexpected option type") - } + require.IsType(&postForkOption{}, opts[0]) // ... and verify them the first time if err := opts[0].Verify(context.Background()); err != nil { diff --git a/vms/proposervm/post_fork_option_test.go b/vms/proposervm/post_fork_option_test.go index 3f6c0ac6bec..690fdf75f71 100644 --- a/vms/proposervm/post_fork_option_test.go +++ b/vms/proposervm/post_fork_option_test.go @@ -10,6 +10,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/database/manager" "github.com/ava-labs/avalanchego/ids" @@ -35,6 +37,8 @@ func (tob TestOptionsBlock) Options(context.Context) ([2]snowman.Block, error) { // ProposerBlock.Verify tests section func TestBlockVerify_PostForkOption_ParentChecks(t *testing.T) { + require := require.New(t) + coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, time.Time{}, 0) proVM.Set(coreGenBlk.Timestamp()) @@ -116,17 +120,13 @@ func TestBlockVerify_PostForkOption_ParentChecks(t *testing.T) { } // retrieve options ... - postForkOracleBlk, ok := parentBlk.(*postForkBlock) - if !ok { - t.Fatal("expected post fork block") - } + require.IsType(&postForkBlock{}, parentBlk) + postForkOracleBlk := parentBlk.(*postForkBlock) opts, err := postForkOracleBlk.Options(context.Background()) if err != nil { t.Fatal("could not retrieve options from post fork oracle block") } - if _, ok := opts[0].(*postForkOption); !ok { - t.Fatal("unexpected option type") - } + require.IsType(&postForkOption{}, opts[0]) // ... and verify them if err := opts[0].Verify(context.Background()); err != nil { @@ -159,9 +159,7 @@ func TestBlockVerify_PostForkOption_ParentChecks(t *testing.T) { if err != nil { t.Fatal("could not build on top of option") } - if _, ok := proChild.(*postForkBlock); !ok { - t.Fatal("unexpected block type") - } + require.IsType(&postForkBlock{}, proChild) if err := proChild.Verify(context.Background()); err != nil { t.Fatal("block built on option does not verify") } @@ -169,6 +167,8 @@ func TestBlockVerify_PostForkOption_ParentChecks(t *testing.T) { // ProposerBlock.Accept tests section func TestBlockVerify_PostForkOption_CoreBlockVerifyIsCalledOnce(t *testing.T) { + require := require.New(t) + // Verify an option once; then show that another verify call would not call coreBlk.Verify() coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, time.Time{}, 0) proVM.Set(coreGenBlk.Timestamp()) @@ -253,17 +253,13 @@ func TestBlockVerify_PostForkOption_CoreBlockVerifyIsCalledOnce(t *testing.T) { } // retrieve options ... - postForkOracleBlk, ok := parentBlk.(*postForkBlock) - if !ok { - t.Fatal("expected post fork block") - } + require.IsType(&postForkBlock{}, parentBlk) + postForkOracleBlk := parentBlk.(*postForkBlock) opts, err := postForkOracleBlk.Options(context.Background()) if err != nil { t.Fatal("could not retrieve options from post fork oracle block") } - if _, ok := opts[0].(*postForkOption); !ok { - t.Fatal("unexpected option type") - } + require.IsType(&postForkOption{}, opts[0]) // ... and verify them the first time if err := opts[0].Verify(context.Background()); err != nil { @@ -287,7 +283,8 @@ func TestBlockVerify_PostForkOption_CoreBlockVerifyIsCalledOnce(t *testing.T) { } func TestBlockAccept_PostForkOption_SetsLastAcceptedBlock(t *testing.T) { - // setup + require := require.New(t) + coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, time.Time{}, 0) proVM.Set(coreGenBlk.Timestamp()) @@ -379,10 +376,8 @@ func TestBlockAccept_PostForkOption_SetsLastAcceptedBlock(t *testing.T) { } // accept one of the options - postForkOracleBlk, ok := parentBlk.(*postForkBlock) - if !ok { - t.Fatal("expected post fork block") - } + require.IsType(&postForkBlock{}, parentBlk) + postForkOracleBlk := parentBlk.(*postForkBlock) opts, err := postForkOracleBlk.Options(context.Background()) if err != nil { t.Fatal("could not retrieve options from post fork oracle block") @@ -407,7 +402,8 @@ func TestBlockAccept_PostForkOption_SetsLastAcceptedBlock(t *testing.T) { // ProposerBlock.Reject tests section func TestBlockReject_InnerBlockIsNotRejected(t *testing.T) { - // setup + require := require.New(t) + coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, time.Time{}, 0) proVM.Set(coreGenBlk.Timestamp()) @@ -485,10 +481,8 @@ func TestBlockReject_InnerBlockIsNotRejected(t *testing.T) { if err := builtBlk.Reject(context.Background()); err != nil { t.Fatal("could not reject block") } - proBlk, ok := builtBlk.(*postForkBlock) - if !ok { - t.Fatal("built block has not expected type") - } + require.IsType(&postForkBlock{}, builtBlk) + proBlk := builtBlk.(*postForkBlock) if proBlk.Status() != choices.Rejected { t.Fatal("block rejection did not set state properly") @@ -499,10 +493,8 @@ func TestBlockReject_InnerBlockIsNotRejected(t *testing.T) { } // reject an option - postForkOracleBlk, ok := builtBlk.(*postForkBlock) - if !ok { - t.Fatal("expected post fork block") - } + require.IsType(&postForkBlock{}, builtBlk) + postForkOracleBlk := builtBlk.(*postForkBlock) opts, err := postForkOracleBlk.Options(context.Background()) if err != nil { t.Fatal("could not retrieve options from post fork oracle block") @@ -511,10 +503,8 @@ func TestBlockReject_InnerBlockIsNotRejected(t *testing.T) { if err := opts[0].Reject(context.Background()); err != nil { t.Fatal("could not accept option") } - proOpt, ok := opts[0].(*postForkOption) - if !ok { - t.Fatal("built block has not expected type") - } + require.IsType(&postForkOption{}, opts[0]) + proOpt := opts[0].(*postForkOption) if proOpt.Status() != choices.Rejected { t.Fatal("block rejection did not set state properly") @@ -526,6 +516,8 @@ func TestBlockReject_InnerBlockIsNotRejected(t *testing.T) { } func TestBlockVerify_PostForkOption_ParentIsNotOracleWithError(t *testing.T) { + require := require.New(t) + // Verify an option once; then show that another verify call would not call coreBlk.Verify() coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, time.Time{}, 0) proVM.Set(coreGenBlk.Timestamp()) @@ -587,10 +579,8 @@ func TestBlockVerify_PostForkOption_ParentIsNotOracleWithError(t *testing.T) { t.Fatal("could not build post fork oracle block") } - postForkBlk, ok := parentBlk.(*postForkBlock) - if !ok { - t.Fatal("expected post fork block") - } + require.IsType(&postForkBlock{}, parentBlk) + postForkBlk := parentBlk.(*postForkBlock) _, err = postForkBlk.Options(context.Background()) if err != snowman.ErrNotOracle { t.Fatal("should have reported that the block isn't an oracle block") diff --git a/vms/proposervm/pre_fork_block_test.go b/vms/proposervm/pre_fork_block_test.go index 08df36b03b0..ada0048c079 100644 --- a/vms/proposervm/pre_fork_block_test.go +++ b/vms/proposervm/pre_fork_block_test.go @@ -51,6 +51,8 @@ func TestOracle_PreForkBlkImplementsInterface(t *testing.T) { } func TestOracle_PreForkBlkCanBuiltOnPreForkOption(t *testing.T) { + require := require.New(t) + coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, mockable.MaxTime, 0) // create pre fork oracle block ... @@ -107,10 +109,8 @@ func TestOracle_PreForkBlkCanBuiltOnPreForkOption(t *testing.T) { } // retrieve options ... - preForkOracleBlk, ok := parentBlk.(*preForkBlock) - if !ok { - t.Fatal("expected pre fork block") - } + require.IsType(&preForkBlock{}, parentBlk) + preForkOracleBlk := parentBlk.(*preForkBlock) opts, err := preForkOracleBlk.Options(context.Background()) if err != nil { t.Fatal("could not retrieve options from pre fork oracle block") @@ -142,12 +142,12 @@ func TestOracle_PreForkBlkCanBuiltOnPreForkOption(t *testing.T) { if err != nil { t.Fatal("could not build pre fork block on pre fork option block") } - if _, ok := preForkChild.(*preForkBlock); !ok { - t.Fatal("expected pre fork block built on pre fork option block") - } + require.IsType(&preForkBlock{}, preForkChild) } func TestOracle_PostForkBlkCanBuiltOnPreForkOption(t *testing.T) { + require := require.New(t) + activationTime := genesisTimestamp.Add(10 * time.Second) coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, activationTime, 0) @@ -210,10 +210,8 @@ func TestOracle_PostForkBlkCanBuiltOnPreForkOption(t *testing.T) { } // retrieve options ... - preForkOracleBlk, ok := parentBlk.(*preForkBlock) - if !ok { - t.Fatal("expected pre fork block") - } + require.IsType(&preForkBlock{}, parentBlk) + preForkOracleBlk := parentBlk.(*preForkBlock) opts, err := preForkOracleBlk.Options(context.Background()) if err != nil { t.Fatal("could not retrieve options from pre fork oracle block") @@ -245,9 +243,7 @@ func TestOracle_PostForkBlkCanBuiltOnPreForkOption(t *testing.T) { if err != nil { t.Fatal("could not build pre fork block on pre fork option block") } - if _, ok := postForkChild.(*postForkBlock); !ok { - t.Fatal("expected pre fork block built on pre fork option block") - } + require.IsType(&postForkBlock{}, postForkChild) } func TestBlockVerify_PreFork_ParentChecks(t *testing.T) { @@ -327,6 +323,8 @@ func TestBlockVerify_PreFork_ParentChecks(t *testing.T) { } func TestBlockVerify_BlocksBuiltOnPreForkGenesis(t *testing.T) { + require := require.New(t) + activationTime := genesisTimestamp.Add(10 * time.Second) coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, activationTime, 0) if !coreGenBlk.Timestamp().Before(activationTime) { @@ -353,9 +351,8 @@ func TestBlockVerify_BlocksBuiltOnPreForkGenesis(t *testing.T) { preForkChild, err := proVM.BuildBlock(context.Background()) if err != nil { t.Fatalf("unexpectedly could not build block due to %s", err) - } else if _, ok := preForkChild.(*preForkBlock); !ok { - t.Fatal("expected preForkBlock") } + require.IsType(&preForkBlock{}, preForkChild) if err := preForkChild.Verify(context.Background()); err != nil { t.Fatal("pre Fork blocks should verify before fork") @@ -428,9 +425,8 @@ func TestBlockVerify_BlocksBuiltOnPreForkGenesis(t *testing.T) { lastPreForkBlk, err := proVM.BuildBlock(context.Background()) if err != nil { t.Fatalf("unexpectedly could not build block due to %s", err) - } else if _, ok := lastPreForkBlk.(*preForkBlock); !ok { - t.Fatal("expected preForkBlock") } + require.IsType(&preForkBlock{}, lastPreForkBlk) if err := lastPreForkBlk.Verify(context.Background()); err != nil { t.Fatal("pre Fork blocks should verify before fork") @@ -468,9 +464,8 @@ func TestBlockVerify_BlocksBuiltOnPreForkGenesis(t *testing.T) { firstPostForkBlk, err := proVM.BuildBlock(context.Background()) if err != nil { t.Fatalf("unexpectedly could not build block due to %s", err) - } else if _, ok := firstPostForkBlk.(*postForkBlock); !ok { - t.Fatal("expected preForkBlock") } + require.IsType(&postForkBlock{}, firstPostForkBlk) if err := firstPostForkBlk.Verify(context.Background()); err != nil { t.Fatal("pre Fork blocks should verify before fork") @@ -478,6 +473,8 @@ func TestBlockVerify_BlocksBuiltOnPreForkGenesis(t *testing.T) { } func TestBlockVerify_BlocksBuiltOnPostForkGenesis(t *testing.T) { + require := require.New(t) + activationTime := genesisTimestamp.Add(-1 * time.Second) coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, activationTime, 0) proVM.Set(activationTime) @@ -501,9 +498,8 @@ func TestBlockVerify_BlocksBuiltOnPostForkGenesis(t *testing.T) { postForkChild, err := proVM.BuildBlock(context.Background()) if err != nil { t.Fatalf("unexpectedly could not build block due to %s", err) - } else if _, ok := postForkChild.(*postForkBlock); !ok { - t.Fatal("expected postForkBlock") } + require.IsType(&postForkBlock{}, postForkChild) if err := postForkChild.Verify(context.Background()); err != nil { t.Fatal("post Fork blocks should verify after fork") @@ -580,6 +576,8 @@ func TestBlockAccept_PreFork_SetsLastAcceptedBlock(t *testing.T) { // ProposerBlock.Reject tests section func TestBlockReject_PreForkBlock_InnerBlockIsRejected(t *testing.T) { + require := require.New(t) + coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, mockable.MaxTime, 0) // disable ProBlks coreBlk := &snowman.TestBlock{ TestDecidable: choices.TestDecidable{ @@ -598,10 +596,8 @@ func TestBlockReject_PreForkBlock_InnerBlockIsRejected(t *testing.T) { if err != nil { t.Fatal("could not build block") } - proBlk, ok := sb.(*preForkBlock) - if !ok { - t.Fatal("built block has not expected type") - } + require.IsType(&preForkBlock{}, sb) + proBlk := sb.(*preForkBlock) if err := proBlk.Reject(context.Background()); err != nil { t.Fatal("could not reject block") diff --git a/vms/proposervm/vm_byzantine_test.go b/vms/proposervm/vm_byzantine_test.go index d471fd700c2..4ce02a2afd6 100644 --- a/vms/proposervm/vm_byzantine_test.go +++ b/vms/proposervm/vm_byzantine_test.go @@ -11,6 +11,8 @@ import ( "testing" "time" + "github.com/stretchr/testify/require" + "github.com/ava-labs/avalanchego/database" "github.com/ava-labs/avalanchego/ids" "github.com/ava-labs/avalanchego/snow/choices" @@ -103,6 +105,8 @@ func TestInvalidByzantineProposerParent(t *testing.T) { // / \ // Y Z func TestInvalidByzantineProposerOracleParent(t *testing.T) { + require := require.New(t) + coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, time.Time{}, 0) proVM.Set(coreGenBlk.Timestamp()) @@ -176,11 +180,8 @@ func TestInvalidByzantineProposerOracleParent(t *testing.T) { t.Fatal("could not build post fork oracle block") } - aBlock, ok := aBlockIntf.(*postForkBlock) - if !ok { - t.Fatal("expected post fork block") - } - + require.IsType(&postForkBlock{}, aBlockIntf) + aBlock := aBlockIntf.(*postForkBlock) opts, err := aBlock.Options(context.Background()) if err != nil { t.Fatal("could not retrieve options from post fork oracle block") @@ -330,6 +331,8 @@ func TestInvalidByzantineProposerPreForkParent(t *testing.T) { // | / // B - Y func TestBlockVerify_PostForkOption_FaultyParent(t *testing.T) { + require := require.New(t) + coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, time.Time{}, 0) proVM.Set(coreGenBlk.Timestamp()) @@ -402,10 +405,8 @@ func TestBlockVerify_PostForkOption_FaultyParent(t *testing.T) { t.Fatal("could not build post fork oracle block") } - aBlock, ok := aBlockIntf.(*postForkBlock) - if !ok { - t.Fatal("expected post fork block") - } + require.IsType(&postForkBlock{}, aBlockIntf) + aBlock := aBlockIntf.(*postForkBlock) opts, err := aBlock.Options(context.Background()) if err != nil { t.Fatal("could not retrieve options from post fork oracle block") diff --git a/vms/proposervm/vm_test.go b/vms/proposervm/vm_test.go index cf2423f484b..a9b45f505ab 100644 --- a/vms/proposervm/vm_test.go +++ b/vms/proposervm/vm_test.go @@ -278,6 +278,8 @@ func TestBuildBlockIsIdempotent(t *testing.T) { } func TestFirstProposerBlockIsBuiltOnTopOfGenesis(t *testing.T) { + require := require.New(t) + // setup coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, time.Time{}, 0) // enable ProBlks @@ -302,10 +304,8 @@ func TestFirstProposerBlockIsBuiltOnTopOfGenesis(t *testing.T) { } // checks - proBlock, ok := snowBlock.(*postForkBlock) - if !ok { - t.Fatal("proposerVM.BuildBlock() does not return a proposervm.Block") - } + require.IsType(&postForkBlock{}, snowBlock) + proBlock := snowBlock.(*postForkBlock) if proBlock.innerBlk != coreBlk { t.Fatal("different block was expected to be built") @@ -716,6 +716,8 @@ func TestTwoProBlocksWithSameParentCanBothVerify(t *testing.T) { // Pre Fork tests section func TestPreFork_Initialize(t *testing.T) { + require := require.New(t) + _, _, proVM, coreGenBlk, _ := initTestProposerVM(t, mockable.MaxTime, 0) // disable ProBlks // checks @@ -729,16 +731,15 @@ func TestPreFork_Initialize(t *testing.T) { t.Fatal("Block should be returned without calling core vm") } - if _, ok := rtvdBlk.(*preForkBlock); !ok { - t.Fatal("Block retrieved from proposerVM should be proposerBlocks") - } + require.IsType(&preForkBlock{}, rtvdBlk) if !bytes.Equal(rtvdBlk.Bytes(), coreGenBlk.Bytes()) { t.Fatal("Stored block is not genesis") } } func TestPreFork_BuildBlock(t *testing.T) { - // setup + require := require.New(t) + coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, mockable.MaxTime, 0) // disable ProBlks coreBlk := &snowman.TestBlock{ @@ -760,9 +761,7 @@ func TestPreFork_BuildBlock(t *testing.T) { if err != nil { t.Fatal("proposerVM could not build block") } - if _, ok := builtBlk.(*preForkBlock); !ok { - t.Fatal("Block built by proposerVM should be proposerBlocks") - } + require.IsType(&preForkBlock{}, builtBlk) if builtBlk.ID() != coreBlk.ID() { t.Fatal("unexpected built block") } @@ -784,6 +783,8 @@ func TestPreFork_BuildBlock(t *testing.T) { } func TestPreFork_ParseBlock(t *testing.T) { + require := require.New(t) + // setup coreVM, _, proVM, _, _ := initTestProposerVM(t, mockable.MaxTime, 0) // disable ProBlks @@ -805,9 +806,7 @@ func TestPreFork_ParseBlock(t *testing.T) { if err != nil { t.Fatal("Could not parse naked core block") } - if _, ok := parsedBlk.(*preForkBlock); !ok { - t.Fatal("Block parsed by proposerVM should be proposerBlocks") - } + require.IsType(&preForkBlock{}, parsedBlk) if parsedBlk.ID() != coreBlk.ID() { t.Fatal("Parsed block does not match expected block") } @@ -1774,6 +1773,8 @@ func TestTooFarAdvanced(t *testing.T) { // B(...) is B(X.opts[0]) // B(...) is C(X.opts[1]) func TestTwoOptions_OneIsAccepted(t *testing.T) { + require := require.New(t) + coreVM, _, proVM, coreGenBlk, _ := initTestProposerVM(t, time.Time{}, 0) proVM.Set(coreGenBlk.Timestamp()) @@ -1818,10 +1819,8 @@ func TestTwoOptions_OneIsAccepted(t *testing.T) { t.Fatal("could not build post fork oracle block") } - aBlock, ok := aBlockIntf.(*postForkBlock) - if !ok { - t.Fatal("expected post fork block") - } + require.IsType(&postForkBlock{}, aBlockIntf) + aBlock := aBlockIntf.(*postForkBlock) opts, err := aBlock.Options(context.Background()) if err != nil { @@ -1887,8 +1886,8 @@ func TestLaggedPChainHeight(t *testing.T) { blockIntf, err := proVM.BuildBlock(context.Background()) require.NoError(err) - block, ok := blockIntf.(*postForkBlock) - require.True(ok, "expected post fork block") + require.IsType(&postForkBlock{}, blockIntf) + block := blockIntf.(*postForkBlock) pChainHeight := block.PChainHeight() require.Equal(pChainHeight, coreGenBlk.Height()) @@ -2291,8 +2290,8 @@ func TestRejectedOptionHeightNotIndexed(t *testing.T) { aBlockIntf, err := proVM.BuildBlock(context.Background()) require.NoError(err) - aBlock, ok := aBlockIntf.(*postForkBlock) - require.True(ok) + require.IsType(&postForkBlock{}, aBlockIntf) + aBlock := aBlockIntf.(*postForkBlock) opts, err := aBlock.Options(context.Background()) require.NoError(err) diff --git a/vms/rpcchainvm/batched_vm_test.go b/vms/rpcchainvm/batched_vm_test.go index cd554af96ab..b0de0f5656e 100644 --- a/vms/rpcchainvm/batched_vm_test.go +++ b/vms/rpcchainvm/batched_vm_test.go @@ -98,8 +98,7 @@ func TestBatchedParseBlockCaching(t *testing.T) { require.NoError(err) require.Equal(blkID1, blk.ID()) - _, typeChecked := blk.(*chain.BlockWrapper) - require.True(typeChecked) + require.IsType(&chain.BlockWrapper{}, blk) // Call should cache the first block and parse the second block blks, err := vm.BatchedParseBlock(context.Background(), [][]byte{blkBytes1, blkBytes2}) @@ -108,11 +107,8 @@ func TestBatchedParseBlockCaching(t *testing.T) { require.Equal(blkID1, blks[0].ID()) require.Equal(blkID2, blks[1].ID()) - _, typeChecked = blks[0].(*chain.BlockWrapper) - require.True(typeChecked) - - _, typeChecked = blks[1].(*chain.BlockWrapper) - require.True(typeChecked) + require.IsType(&chain.BlockWrapper{}, blks[0]) + require.IsType(&chain.BlockWrapper{}, blks[1]) // Call should be fully cached and not result in a grpc call blks, err = vm.BatchedParseBlock(context.Background(), [][]byte{blkBytes1, blkBytes2}) @@ -121,9 +117,6 @@ func TestBatchedParseBlockCaching(t *testing.T) { require.Equal(blkID1, blks[0].ID()) require.Equal(blkID2, blks[1].ID()) - _, typeChecked = blks[0].(*chain.BlockWrapper) - require.True(typeChecked) - - _, typeChecked = blks[1].(*chain.BlockWrapper) - require.True(typeChecked) + require.IsType(&chain.BlockWrapper{}, blks[0]) + require.IsType(&chain.BlockWrapper{}, blks[1]) } diff --git a/vms/secp256k1fx/keychain_test.go b/vms/secp256k1fx/keychain_test.go index 56a3a5925f4..19da3e4224a 100644 --- a/vms/secp256k1fx/keychain_test.go +++ b/vms/secp256k1fx/keychain_test.go @@ -54,8 +54,8 @@ func TestKeychainAdd(t *testing.T) { addr, _ := ids.ShortFromString(addrs[0]) rsk, exists := kc.Get(addr) require.True(exists) - rsksecp, ok := rsk.(*secp256k1.PrivateKey) - require.True(ok, "Factory should have returned secp256k1r private key") + require.IsType(&secp256k1.PrivateKey{}, rsk) + rsksecp := rsk.(*secp256k1.PrivateKey) require.Equal(sk.Bytes(), rsksecp.Bytes()) addrs := kc.Addresses() @@ -157,8 +157,8 @@ func TestKeychainSpendMint(t *testing.T) { vinput, keys, err := kc.Spend(&mint, 0) require.NoError(err) - input, ok := vinput.(*Input) - require.True(ok) + require.IsType(&Input{}, vinput) + input := vinput.(*Input) require.NoError(input.Verify()) require.Equal([]uint32{0, 1}, input.SigIndices) require.Len(keys, 2) @@ -206,8 +206,8 @@ func TestKeychainSpendTransfer(t *testing.T) { vinput, keys, err := kc.Spend(&transfer, 54321) require.NoError(err) - input, ok := vinput.(*TransferInput) - require.True(ok) + require.IsType(&TransferInput{}, vinput) + input := vinput.(*TransferInput) require.NoError(input.Verify()) require.Equal(uint64(12345), input.Amount()) require.Equal([]uint32{0, 1}, input.SigIndices) diff --git a/x/merkledb/db_test.go b/x/merkledb/db_test.go index 21a9867cceb..940b1f55ee4 100644 --- a/x/merkledb/db_test.go +++ b/x/merkledb/db_test.go @@ -523,8 +523,8 @@ func TestDatabaseCommitChanges(t *testing.T) { // Make a view and inser/delete a key-value pair. view1Intf, err := db.NewView() require.NoError(err) - view1, ok := view1Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view1Intf) + view1 := view1Intf.(*trieView) err = view1.Insert(context.Background(), []byte{3}, []byte{3}) require.NoError(err) err = view1.Remove(context.Background(), []byte{1}) @@ -535,14 +535,14 @@ func TestDatabaseCommitChanges(t *testing.T) { // Make a second view view2Intf, err := db.NewView() require.NoError(err) - view2, ok := view2Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view2Intf) + view2 := view2Intf.(*trieView) // Make a view atop a view view3Intf, err := view1.NewView() require.NoError(err) - view3, ok := view3Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view3Intf) + view3 := view3Intf.(*trieView) // view3 // | @@ -592,18 +592,18 @@ func TestDatabaseInvalidateChildrenExcept(t *testing.T) { // Create children view1Intf, err := db.NewView() require.NoError(err) - view1, ok := view1Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view1Intf) + view1 := view1Intf.(*trieView) view2Intf, err := db.NewView() require.NoError(err) - view2, ok := view2Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view2Intf) + view2 := view2Intf.(*trieView) view3Intf, err := db.NewView() require.NoError(err) - view3, ok := view3Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view3Intf) + view3 := view3Intf.(*trieView) db.invalidateChildrenExcept(view1) diff --git a/x/merkledb/trie_test.go b/x/merkledb/trie_test.go index e7c75e52bf4..88ca81f26a2 100644 --- a/x/merkledb/trie_test.go +++ b/x/merkledb/trie_test.go @@ -109,8 +109,8 @@ func TestTrieViewGetPathTo(t *testing.T) { trieIntf, err := db.NewView() require.NoError(err) - trie, ok := trieIntf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, trieIntf) + trie := trieIntf.(*trieView) path, err := trie.getPathTo(newPath(nil)) require.NoError(err) @@ -549,8 +549,8 @@ func Test_Trie_CommitChanges(t *testing.T) { view1Intf, err := db.NewView() require.NoError(err) - view1, ok := view1Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view1Intf) + view1 := view1Intf.(*trieView) err = view1.Insert(context.Background(), []byte{1}, []byte{1}) require.NoError(err) @@ -588,8 +588,8 @@ func Test_Trie_CommitChanges(t *testing.T) { // Make more views atop the existing one view2Intf, err := view1.NewView() require.NoError(err) - view2, ok := view2Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view2Intf) + view2 := view2Intf.(*trieView) err = view2.Insert(context.Background(), []byte{2}, []byte{2}) require.NoError(err) @@ -604,13 +604,13 @@ func Test_Trie_CommitChanges(t *testing.T) { view3Intf, err := view1.NewView() require.NoError(err) - view3, ok := view3Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view3Intf) + view3 := view3Intf.(*trieView) view4Intf, err := view2.NewView() require.NoError(err) - view4, ok := view4Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view4Intf) + view4 := view4Intf.(*trieView) // view4 // | @@ -914,8 +914,8 @@ func TestNewViewOnCommittedView(t *testing.T) { // Create a view view1Intf, err := db.NewView() require.NoError(err) - view1, ok := view1Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view1Intf) + view1 := view1Intf.(*trieView) // view1 // | @@ -943,8 +943,8 @@ func TestNewViewOnCommittedView(t *testing.T) { // Create a new view on the committed view view2Intf, err := view1.NewView() require.NoError(err) - view2, ok := view2Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view2Intf) + view2 := view2Intf.(*trieView) // view2 // | @@ -965,8 +965,8 @@ func TestNewViewOnCommittedView(t *testing.T) { // Make another view view3Intf, err := view2.NewView() require.NoError(err) - view3, ok := view3Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view3Intf) + view3 := view3Intf.(*trieView) // view3 // | @@ -1022,14 +1022,14 @@ func Test_TrieView_NewView(t *testing.T) { // Create a view view1Intf, err := db.NewView() require.NoError(err) - view1, ok := view1Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view1Intf) + view1 := view1Intf.(*trieView) // Create a view atop view1 view2Intf, err := view1.NewView() require.NoError(err) - view2, ok := view2Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view2Intf) + view2 := view2Intf.(*trieView) // view2 // | @@ -1049,8 +1049,8 @@ func Test_TrieView_NewView(t *testing.T) { // Make another view atop view1 view3Intf, err := view1.NewView() require.NoError(err) - view3, ok := view3Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view3Intf) + view3 := view3Intf.(*trieView) // view3 // | @@ -1080,19 +1080,19 @@ func TestTrieViewInvalidate(t *testing.T) { // Create a view view1Intf, err := db.NewView() require.NoError(err) - view1, ok := view1Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view1Intf) + view1 := view1Intf.(*trieView) // Create 2 views atop view1 view2Intf, err := view1.NewView() require.NoError(err) - view2, ok := view2Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view2Intf) + view2 := view2Intf.(*trieView) view3Intf, err := view1.NewView() require.NoError(err) - view3, ok := view3Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view3Intf) + view3 := view3Intf.(*trieView) // view2 view3 // | / @@ -1118,20 +1118,20 @@ func TestTrieViewMoveChildViewsToView(t *testing.T) { // Create a view view1Intf, err := db.NewView() require.NoError(err) - view1, ok := view1Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view1Intf) + view1 := view1Intf.(*trieView) // Create a view atop view1 view2Intf, err := view1.NewView() require.NoError(err) - view2, ok := view2Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view2Intf) + view2 := view2Intf.(*trieView) // Create a view atop view2 view3Intf, err := view1.NewView() require.NoError(err) - view3, ok := view3Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view3Intf) + view3 := view3Intf.(*trieView) // view3 // | @@ -1158,19 +1158,19 @@ func TestTrieViewInvalidChildrenExcept(t *testing.T) { // Create a view view1Intf, err := db.NewView() require.NoError(err) - view1, ok := view1Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view1Intf) + view1 := view1Intf.(*trieView) // Create 2 views atop view1 view2Intf, err := view1.NewView() require.NoError(err) - view2, ok := view2Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view2Intf) + view2 := view2Intf.(*trieView) view3Intf, err := view1.NewView() require.NoError(err) - view3, ok := view3Intf.(*trieView) - require.True(ok) + require.IsType(&trieView{}, view3Intf) + view3 := view3Intf.(*trieView) view1.invalidateChildrenExcept(view2)