Skip to content

Commit

Permalink
Use require.IsType for type assertions in tests (ava-labs#1458)
Browse files Browse the repository at this point in the history
  • Loading branch information
dhrubabasu committed May 3, 2023
1 parent eb8b52a commit b3a07d8
Show file tree
Hide file tree
Showing 41 changed files with 379 additions and 438 deletions.
4 changes: 2 additions & 2 deletions api/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@ func TestNewTokenHappyPath(t *testing.T) {
})
require.NoError(t, err, "couldn't parse new token")

claims, ok := token.Claims.(*endpointClaims)
require.True(t, ok, "expected auth token's claims to be type endpointClaims but is different type")
require.IsType(t, &endpointClaims{}, token.Claims)
claims := token.Claims.(*endpointClaims)
require.ElementsMatch(t, endpoints, claims.Endpoints, "token has wrong endpoint claims")

shouldExpireAt := jwt.NewNumericDate(now.Add(defaultTokenLifespan))
Expand Down
18 changes: 6 additions & 12 deletions database/manager/manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -306,12 +306,9 @@ func TestMeterDBManager(t *testing.T) {
dbs := manager.GetDatabases()
require.Len(dbs, 3)

_, ok := dbs[0].Database.(*meterdb.Database)
require.True(ok)
_, ok = dbs[1].Database.(*meterdb.Database)
require.False(ok)
_, ok = dbs[2].Database.(*meterdb.Database)
require.False(ok)
require.IsType(&meterdb.Database{}, dbs[0].Database)
require.IsType(&memdb.Database{}, dbs[1].Database)
require.IsType(&memdb.Database{}, dbs[2].Database)

// Confirm that the error from a name conflict is handled correctly
_, err = m.NewMeterDBManager("", registry)
Expand Down Expand Up @@ -355,12 +352,9 @@ func TestCompleteMeterDBManager(t *testing.T) {
dbs := manager.GetDatabases()
require.Len(dbs, 3)

_, ok := dbs[0].Database.(*meterdb.Database)
require.True(ok)
_, ok = dbs[1].Database.(*meterdb.Database)
require.True(ok)
_, ok = dbs[2].Database.(*meterdb.Database)
require.True(ok)
require.IsType(&meterdb.Database{}, dbs[0].Database)
require.IsType(&meterdb.Database{}, dbs[1].Database)
require.IsType(&meterdb.Database{}, dbs[2].Database)

// Confirm that the error from a name conflict is handled correctly
_, err = m.NewCompleteMeterDBManager("", registry)
Expand Down
39 changes: 19 additions & 20 deletions indexer/indexer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,8 @@ func TestNewIndexer(t *testing.T) {

idxrIntf, err := NewIndexer(config)
require.NoError(err)
idxr, ok := idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
idxr := idxrIntf.(*indexer)
require.NotNil(idxr.codec)
require.NotNil(idxr.log)
require.NotNil(idxr.db)
Expand Down Expand Up @@ -118,8 +118,8 @@ func TestMarkHasRunAndShutdown(t *testing.T) {
config.DB = versiondb.New(baseDB)
idxrIntf, err = NewIndexer(config)
require.NoError(err)
idxr, ok := idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
idxr := idxrIntf.(*indexer)
require.True(idxr.hasRunBefore)
require.NoError(idxr.Close())
shutdown.Wait()
Expand Down Expand Up @@ -150,8 +150,8 @@ func TestIndexer(t *testing.T) {
// Create indexer
idxrIntf, err := NewIndexer(config)
require.NoError(err)
idxr, ok := idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
idxr := idxrIntf.(*indexer)
now := time.Now()
idxr.clock.Set(now)

Expand Down Expand Up @@ -232,10 +232,10 @@ func TestIndexer(t *testing.T) {
config.DB = versiondb.New(baseDB)
idxrIntf, err = NewIndexer(config)
require.NoError(err)
idxr, ok = idxrIntf.(*indexer)
require.IsType(&indexer{}, idxrIntf)
idxr = idxrIntf.(*indexer)
now = time.Now()
idxr.clock.Set(now)
require.True(ok)
require.Len(idxr.blockIndices, 0)
require.Len(idxr.txIndices, 0)
require.Len(idxr.vtxIndices, 0)
Expand Down Expand Up @@ -389,8 +389,8 @@ func TestIndexer(t *testing.T) {
config.DB = versiondb.New(baseDB)
idxrIntf, err = NewIndexer(config)
require.NoError(err)
idxr, ok = idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
idxr = idxrIntf.(*indexer)
idxr.RegisterChain("chain1", chain1Ctx, chainVM)
idxr.RegisterChain("chain2", chain2Ctx, dagVM)

Expand Down Expand Up @@ -427,8 +427,8 @@ func TestIncompleteIndex(t *testing.T) {
}
idxrIntf, err := NewIndexer(config)
require.NoError(err)
idxr, ok := idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
idxr := idxrIntf.(*indexer)
require.False(idxr.indexingEnabled)

// Register a chain
Expand All @@ -454,8 +454,8 @@ func TestIncompleteIndex(t *testing.T) {
config.DB = versiondb.New(baseDB)
idxrIntf, err = NewIndexer(config)
require.NoError(err)
idxr, ok = idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
idxr = idxrIntf.(*indexer)
require.True(idxr.indexingEnabled)

// Register the chain again. Should die due to incomplete index.
Expand All @@ -470,8 +470,8 @@ func TestIncompleteIndex(t *testing.T) {
config.DB = versiondb.New(baseDB)
idxrIntf, err = NewIndexer(config)
require.NoError(err)
idxr, ok = idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
idxr = idxrIntf.(*indexer)
require.True(idxr.allowIncompleteIndex)

// Register the chain again. Should be OK
Expand All @@ -486,8 +486,7 @@ func TestIncompleteIndex(t *testing.T) {
config.DB = versiondb.New(baseDB)
idxrIntf, err = NewIndexer(config)
require.NoError(err)
_, ok = idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
}

// Ensure we only index chains in the primary network
Expand All @@ -513,8 +512,8 @@ func TestIgnoreNonDefaultChains(t *testing.T) {
// Create indexer
idxrIntf, err := NewIndexer(config)
require.NoError(err)
idxr, ok := idxrIntf.(*indexer)
require.True(ok)
require.IsType(&indexer{}, idxrIntf)
idxr := idxrIntf.(*indexer)

// Assert state is right
chain1Ctx := snow.DefaultConsensusContextTest()
Expand Down
52 changes: 26 additions & 26 deletions message/inbound_msg_builder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(nodeID, msg.NodeID())
require.False(msg.Expiration().Before(start.Add(deadline)))
require.False(end.Add(deadline).Before(msg.Expiration()))
innerMsg, ok := msg.Message().(*p2p.GetStateSummaryFrontier)
require.True(ok)
require.IsType(&p2p.GetStateSummaryFrontier{}, msg.Message())
innerMsg := msg.Message().(*p2p.GetStateSummaryFrontier)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
},
Expand All @@ -87,8 +87,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(StateSummaryFrontierOp, msg.Op())
require.Equal(nodeID, msg.NodeID())
require.Equal(mockable.MaxTime, msg.Expiration())
innerMsg, ok := msg.Message().(*p2p.StateSummaryFrontier)
require.True(ok)
require.IsType(&p2p.StateSummaryFrontier{}, msg.Message())
innerMsg := msg.Message().(*p2p.StateSummaryFrontier)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
require.Equal(summary, innerMsg.Summary)
Expand All @@ -114,8 +114,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(nodeID, msg.NodeID())
require.False(msg.Expiration().Before(start.Add(deadline)))
require.False(end.Add(deadline).Before(msg.Expiration()))
innerMsg, ok := msg.Message().(*p2p.GetAcceptedStateSummary)
require.True(ok)
require.IsType(&p2p.GetAcceptedStateSummary{}, msg.Message())
innerMsg := msg.Message().(*p2p.GetAcceptedStateSummary)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
require.Equal(heights, innerMsg.Heights)
Expand All @@ -137,8 +137,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(AcceptedStateSummaryOp, msg.Op())
require.Equal(nodeID, msg.NodeID())
require.Equal(mockable.MaxTime, msg.Expiration())
innerMsg, ok := msg.Message().(*p2p.AcceptedStateSummary)
require.True(ok)
require.IsType(&p2p.AcceptedStateSummary{}, msg.Message())
innerMsg := msg.Message().(*p2p.AcceptedStateSummary)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
summaryIDsBytes := make([][]byte, len(summaryIDs))
Expand Down Expand Up @@ -169,8 +169,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(nodeID, msg.NodeID())
require.False(msg.Expiration().Before(start.Add(deadline)))
require.False(end.Add(deadline).Before(msg.Expiration()))
innerMsg, ok := msg.Message().(*p2p.GetAcceptedFrontier)
require.True(ok)
require.IsType(&p2p.GetAcceptedFrontier{}, msg.Message())
innerMsg := msg.Message().(*p2p.GetAcceptedFrontier)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
require.Equal(engineType, innerMsg.EngineType)
Expand All @@ -192,8 +192,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(AcceptedFrontierOp, msg.Op())
require.Equal(nodeID, msg.NodeID())
require.Equal(mockable.MaxTime, msg.Expiration())
innerMsg, ok := msg.Message().(*p2p.AcceptedFrontier)
require.True(ok)
require.IsType(&p2p.AcceptedFrontier{}, msg.Message())
innerMsg := msg.Message().(*p2p.AcceptedFrontier)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
containerIDsBytes := make([][]byte, len(containerIDs))
Expand Down Expand Up @@ -225,8 +225,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(nodeID, msg.NodeID())
require.False(msg.Expiration().Before(start.Add(deadline)))
require.False(end.Add(deadline).Before(msg.Expiration()))
innerMsg, ok := msg.Message().(*p2p.GetAccepted)
require.True(ok)
require.IsType(&p2p.GetAccepted{}, msg.Message())
innerMsg := msg.Message().(*p2p.GetAccepted)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
require.Equal(engineType, innerMsg.EngineType)
Expand All @@ -248,8 +248,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(AcceptedOp, msg.Op())
require.Equal(nodeID, msg.NodeID())
require.Equal(mockable.MaxTime, msg.Expiration())
innerMsg, ok := msg.Message().(*p2p.Accepted)
require.True(ok)
require.IsType(&p2p.Accepted{}, msg.Message())
innerMsg := msg.Message().(*p2p.Accepted)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
containerIDsBytes := make([][]byte, len(containerIDs))
Expand Down Expand Up @@ -281,8 +281,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(nodeID, msg.NodeID())
require.False(msg.Expiration().Before(start.Add(deadline)))
require.False(end.Add(deadline).Before(msg.Expiration()))
innerMsg, ok := msg.Message().(*p2p.PushQuery)
require.True(ok)
require.IsType(&p2p.PushQuery{}, msg.Message())
innerMsg := msg.Message().(*p2p.PushQuery)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
require.Equal(container, innerMsg.Container)
Expand Down Expand Up @@ -310,8 +310,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(nodeID, msg.NodeID())
require.False(msg.Expiration().Before(start.Add(deadline)))
require.False(end.Add(deadline).Before(msg.Expiration()))
innerMsg, ok := msg.Message().(*p2p.PullQuery)
require.True(ok)
require.IsType(&p2p.PullQuery{}, msg.Message())
innerMsg := msg.Message().(*p2p.PullQuery)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
require.Equal(containerIDs[0][:], innerMsg.ContainerId)
Expand All @@ -335,8 +335,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(ChitsOp, msg.Op())
require.Equal(nodeID, msg.NodeID())
require.Equal(mockable.MaxTime, msg.Expiration())
innerMsg, ok := msg.Message().(*p2p.Chits)
require.True(ok)
require.IsType(&p2p.Chits{}, msg.Message())
innerMsg := msg.Message().(*p2p.Chits)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
containerIDsBytes := make([][]byte, len(containerIDs))
Expand Down Expand Up @@ -373,8 +373,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(nodeID, msg.NodeID())
require.False(msg.Expiration().Before(start.Add(deadline)))
require.False(end.Add(deadline).Before(msg.Expiration()))
innerMsg, ok := msg.Message().(*p2p.AppRequest)
require.True(ok)
require.IsType(&p2p.AppRequest{}, msg.Message())
innerMsg := msg.Message().(*p2p.AppRequest)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
require.Equal(appBytes, innerMsg.AppBytes)
Expand All @@ -396,8 +396,8 @@ func TestInboundMsgBuilder(t *testing.T) {
require.Equal(AppResponseOp, msg.Op())
require.Equal(nodeID, msg.NodeID())
require.Equal(mockable.MaxTime, msg.Expiration())
innerMsg, ok := msg.Message().(*p2p.AppResponse)
require.True(ok)
require.IsType(&p2p.AppResponse{}, msg.Message())
innerMsg := msg.Message().(*p2p.AppResponse)
require.Equal(chainID[:], innerMsg.ChainId)
require.Equal(requestID, innerMsg.RequestId)
require.Equal(appBytes, innerMsg.AppBytes)
Expand Down
4 changes: 2 additions & 2 deletions message/messages_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -872,7 +872,7 @@ func TestNilInboundMessage(t *testing.T) {
parsedMsg, err := mb.parseInbound(msgBytes, ids.EmptyNodeID, func() {})
require.NoError(err)

pingMsg, ok := parsedMsg.message.(*p2p.Ping)
require.True(ok)
require.IsType(&p2p.Ping{}, parsedMsg.message)
pingMsg := parsedMsg.message.(*p2p.Ping)
require.NotNil(pingMsg)
}
4 changes: 2 additions & 2 deletions network/throttling/bandwidth_throttler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ func TestBandwidthThrottler(t *testing.T) {
}
throttlerIntf, err := newBandwidthThrottler(logging.NoLog{}, "", prometheus.NewRegistry(), config)
require.NoError(err)
throttler, ok := throttlerIntf.(*bandwidthThrottlerImpl)
require.True(ok)
require.IsType(&bandwidthThrottlerImpl{}, throttlerIntf)
throttler := throttlerIntf.(*bandwidthThrottlerImpl)
require.NotNil(throttler.log)
require.NotNil(throttler.limiters)
require.Equal(config.RefillRate, throttler.RefillRate)
Expand Down
4 changes: 2 additions & 2 deletions network/throttling/inbound_resource_throttler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ func TestNewSystemThrottler(t *testing.T) {
targeter := tracker.NewMockTargeter(ctrl)
throttlerIntf, err := NewSystemThrottler("", reg, config, cpuTracker, targeter)
require.NoError(err)
throttler, ok := throttlerIntf.(*systemThrottler)
require.True(ok)
require.IsType(&systemThrottler{}, throttlerIntf)
throttler := throttlerIntf.(*systemThrottler)
require.Equal(clock, config.Clock)
require.Equal(time.Second, config.MaxRecheckDelay)
require.Equal(cpuTracker, throttler.tracker)
Expand Down
10 changes: 6 additions & 4 deletions snow/consensus/snowball/unary_snowball_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package snowball

import (
"testing"

"github.com/stretchr/testify/require"
)

func UnarySnowballStateTest(t *testing.T, sb *unarySnowball, expectedNumSuccessfulPolls, expectedConfidence int, expectedFinalized bool) {
Expand All @@ -18,6 +20,8 @@ func UnarySnowballStateTest(t *testing.T, sb *unarySnowball, expectedNumSuccessf
}

func TestUnarySnowball(t *testing.T) {
require := require.New(t)

beta := 2

sb := &unarySnowball{}
Expand All @@ -33,10 +37,8 @@ func TestUnarySnowball(t *testing.T) {
UnarySnowballStateTest(t, sb, 2, 1, false)

sbCloneIntf := sb.Clone()
sbClone, ok := sbCloneIntf.(*unarySnowball)
if !ok {
t.Fatalf("Unexpected clone type")
}
require.IsType(&unarySnowball{}, sbCloneIntf)
sbClone := sbCloneIntf.(*unarySnowball)

UnarySnowballStateTest(t, sbClone, 2, 1, false)

Expand Down
10 changes: 6 additions & 4 deletions snow/consensus/snowball/unary_snowflake_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ package snowball

import (
"testing"

"github.com/stretchr/testify/require"
)

func UnarySnowflakeStateTest(t *testing.T, sf *unarySnowflake, expectedConfidence int, expectedFinalized bool) {
Expand All @@ -16,6 +18,8 @@ func UnarySnowflakeStateTest(t *testing.T, sf *unarySnowflake, expectedConfidenc
}

func TestUnarySnowflake(t *testing.T) {
require := require.New(t)

beta := 2

sf := &unarySnowflake{}
Expand All @@ -31,10 +35,8 @@ func TestUnarySnowflake(t *testing.T) {
UnarySnowflakeStateTest(t, sf, 1, false)

sfCloneIntf := sf.Clone()
sfClone, ok := sfCloneIntf.(*unarySnowflake)
if !ok {
t.Fatalf("Unexpected clone type")
}
require.IsType(&unarySnowflake{}, sfCloneIntf)
sfClone := sfCloneIntf.(*unarySnowflake)

UnarySnowflakeStateTest(t, sfClone, 1, false)

Expand Down
Loading

0 comments on commit b3a07d8

Please sign in to comment.