Skip to content

Commit

Permalink
refactor: simplify testing setup for callbacks
Browse files Browse the repository at this point in the history
  • Loading branch information
colin-axner committed Aug 8, 2023
1 parent 10db12e commit e0c2f33
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 130 deletions.
108 changes: 64 additions & 44 deletions modules/apps/callbacks/callbacks_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,59 +135,79 @@ func (s *CallbacksTestSuite) RegisterInterchainAccount(owner string) {
s.path.EndpointA.ChannelID = channelID
}

// AssertHasExecutedExpectedCallback checks if only the expected type of callback has been executed.
// AssertHasExecutedExpectedCallback checks the stateful entries and counters based on callbacktype.
// It assumes that the source chain is chainA and the destination chain is chainB.
//
// The callbackType can be one of the following:
// - types.CallbackTypeAcknowledgement
// - types.CallbackTypeWriteAcknowledgement
// - types.CallbackTypeTimeout
// - "none" (no callback should be executed)
func (s *CallbacksTestSuite) AssertHasExecutedExpectedCallback(callbackType types.CallbackType, isSuccessful bool) {
successCount := uint64(0)
if isSuccessful {
successCount = 1
func (s *CallbacksTestSuite) AssertHasExecutedExpectedCallback(callbackType types.CallbackType, expSuccess bool) {
var expStatefulEntries uint8
if expSuccess {
// if the callback is expected to be successful,
// we expect at least one state entry
expStatefulEntries = 1
}

sourceStatefulCounter := s.chainA.GetSimApp().MockKeeper.GetStateCounter(s.chainA.GetContext())
destStatefulCounter := s.chainB.GetSimApp().MockKeeper.GetStateCounter(s.chainB.GetContext())

switch callbackType {
case "none":
s.Require().Equal(uint8(0), sourceStatefulCounter)
s.Require().Equal(uint8(0), destStatefulCounter)

case types.CallbackTypeSendPacket:
s.Require().Equal(expStatefulEntries, sourceStatefulCounter)
s.Require().Equal(uint8(0), destStatefulCounter)

case types.CallbackTypeAcknowledgement, types.CallbackTypeTimeoutPacket:
expStatefulEntries *= 2 // expect OnAcknowledgement/OnTimeout to be successful as well
s.Require().Equal(expStatefulEntries, sourceStatefulCounter)
s.Require().Equal(uint8(0), destStatefulCounter)

case types.CallbackTypeWriteAcknowledgement:
s.Require().Equal(uint8(0), sourceStatefulCounter)
s.Require().Equal(expStatefulEntries, destStatefulCounter)

default:
s.FailNow(fmt.Sprintf("invalid callback type %s", callbackType))
}

s.AssertCallbackCounters(callbackType)
}

func (s *CallbacksTestSuite) AssertCallbackCounters(callbackType types.CallbackType) {
sourceCounters := s.chainA.GetSimApp().MockKeeper.Counters
destCounters := s.chainB.GetSimApp().MockKeeper.Counters

switch callbackType {
case "none":
s.Require().Len(sourceCounters, 0)
s.Require().Len(destCounters, 0)

case types.CallbackTypeSendPacket:
s.Require().Len(sourceCounters, 1)
s.Require().Equal(1, sourceCounters[types.CallbackTypeSendPacket])

case types.CallbackTypeAcknowledgement:
s.Require().Equal(successCount, s.chainA.GetSimApp().MockKeeper.AckCallbackCounter.Success)
s.Require().Equal(1-successCount, s.chainA.GetSimApp().MockKeeper.AckCallbackCounter.Failure)
s.Require().Equal(successCount, s.chainA.GetSimApp().MockKeeper.SendPacketCallbackCounter.Success)
s.Require().Equal(1-successCount, s.chainA.GetSimApp().MockKeeper.SendPacketCallbackCounter.Failure)
s.Require().Equal(uint8(2*successCount), s.chainA.GetSimApp().MockKeeper.GetStateCounter(s.chainA.GetContext()))
s.Require().Equal(uint8(0), s.chainB.GetSimApp().MockKeeper.GetStateCounter(s.chainB.GetContext()))
s.Require().True(s.chainA.GetSimApp().MockKeeper.TimeoutCallbackCounter.IsZero())
s.Require().True(s.chainB.GetSimApp().MockKeeper.WriteAcknowledgementCallbackCounter.IsZero())
s.Require().Len(sourceCounters, 2)
s.Require().Equal(1, sourceCounters[types.CallbackTypeSendPacket])
s.Require().Equal(1, sourceCounters[types.CallbackTypeAcknowledgement])

s.Require().Len(destCounters, 0)

case types.CallbackTypeWriteAcknowledgement:
s.Require().Equal(successCount, s.chainB.GetSimApp().MockKeeper.WriteAcknowledgementCallbackCounter.Success)
s.Require().Equal(1-successCount, s.chainB.GetSimApp().MockKeeper.WriteAcknowledgementCallbackCounter.Failure)
s.Require().Equal(uint8(successCount), s.chainB.GetSimApp().MockKeeper.GetStateCounter(s.chainB.GetContext()))
s.Require().Equal(uint8(0), s.chainA.GetSimApp().MockKeeper.GetStateCounter(s.chainA.GetContext()))
s.Require().True(s.chainA.GetSimApp().MockKeeper.SendPacketCallbackCounter.IsZero())
s.Require().True(s.chainA.GetSimApp().MockKeeper.TimeoutCallbackCounter.IsZero())
s.Require().True(s.chainB.GetSimApp().MockKeeper.AckCallbackCounter.IsZero())
s.Require().Len(sourceCounters, 0)
s.Require().Len(destCounters, 1)
s.Require().Equal(1, destCounters[types.CallbackTypeWriteAcknowledgement])

case types.CallbackTypeTimeoutPacket:
s.Require().Equal(successCount, s.chainA.GetSimApp().MockKeeper.TimeoutCallbackCounter.Success)
s.Require().Equal(1-successCount, s.chainA.GetSimApp().MockKeeper.TimeoutCallbackCounter.Failure)
s.Require().Equal(successCount, s.chainA.GetSimApp().MockKeeper.SendPacketCallbackCounter.Success)
s.Require().Equal(1-successCount, s.chainA.GetSimApp().MockKeeper.SendPacketCallbackCounter.Failure)
s.Require().Equal(uint8(2*successCount), s.chainA.GetSimApp().MockKeeper.GetStateCounter(s.chainA.GetContext()))
s.Require().Equal(uint8(0), s.chainB.GetSimApp().MockKeeper.GetStateCounter(s.chainB.GetContext()))
s.Require().True(s.chainA.GetSimApp().MockKeeper.AckCallbackCounter.IsZero())
s.Require().True(s.chainB.GetSimApp().MockKeeper.WriteAcknowledgementCallbackCounter.IsZero())
case "none":
s.Require().True(s.chainA.GetSimApp().MockKeeper.AckCallbackCounter.IsZero())
s.Require().True(s.chainA.GetSimApp().MockKeeper.TimeoutCallbackCounter.IsZero())
s.Require().True(s.chainB.GetSimApp().MockKeeper.WriteAcknowledgementCallbackCounter.IsZero())
s.Require().True(s.chainA.GetSimApp().MockKeeper.SendPacketCallbackCounter.IsZero())
s.Require().Equal(uint8(0), s.chainA.GetSimApp().MockKeeper.GetStateCounter(s.chainA.GetContext()))
s.Require().Equal(uint8(0), s.chainB.GetSimApp().MockKeeper.GetStateCounter(s.chainB.GetContext()))
s.Require().Len(sourceCounters, 2)
s.Require().Equal(1, sourceCounters[types.CallbackTypeSendPacket])
s.Require().Equal(1, sourceCounters[types.CallbackTypeTimeoutPacket])

s.Require().Len(destCounters, 0)

default:
s.FailNow(fmt.Sprintf("invalid callback type %s", callbackType))
}
s.Require().True(s.chainB.GetSimApp().MockKeeper.AckCallbackCounter.IsZero())
s.Require().True(s.chainB.GetSimApp().MockKeeper.TimeoutCallbackCounter.IsZero())
s.Require().True(s.chainA.GetSimApp().MockKeeper.WriteAcknowledgementCallbackCounter.IsZero())
}

func TestIBCCallbacksTestSuite(t *testing.T) {
Expand Down
38 changes: 21 additions & 17 deletions modules/apps/callbacks/fee_transfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,19 +106,21 @@ func (s *CallbacksTestSuite) TestIncentivizedTransferCallbacks() {
}

for _, tc := range testCases {
s.SetupFeeTransferTest()
s.Run(tc.name, func() {
s.SetupFeeTransferTest()

fee := feetypes.NewFee(defaultRecvFee, defaultAckFee, defaultTimeoutFee)
fee := feetypes.NewFee(defaultRecvFee, defaultAckFee, defaultTimeoutFee)

s.ExecutePayPacketFeeMsg(fee)
preRelaySenderBalance := sdk.NewCoins(s.chainA.GetSimApp().BankKeeper.GetBalance(s.chainA.GetContext(), s.chainA.SenderAccount.GetAddress(), ibctesting.TestCoin.Denom))
s.ExecuteTransfer(tc.transferMemo)
// we manually subtract the transfer amount from the preRelaySenderBalance because ExecuteTransfer
// also relays the packet, which will trigger the fee payments.
preRelaySenderBalance = preRelaySenderBalance.Sub(ibctesting.TestCoin)
s.ExecutePayPacketFeeMsg(fee)
preRelaySenderBalance := sdk.NewCoins(s.chainA.GetSimApp().BankKeeper.GetBalance(s.chainA.GetContext(), s.chainA.SenderAccount.GetAddress(), ibctesting.TestCoin.Denom))
s.ExecuteTransfer(tc.transferMemo)
// we manually subtract the transfer amount from the preRelaySenderBalance because ExecuteTransfer
// also relays the packet, which will trigger the fee payments.
preRelaySenderBalance = preRelaySenderBalance.Sub(ibctesting.TestCoin)

// after incentivizing the packets
s.AssertHasExecutedExpectedCallbackWithFee(tc.expCallbackType, tc.expSuccess, false, preRelaySenderBalance, fee)
// after incentivizing the packets
s.AssertHasExecutedExpectedCallbackWithFee(tc.expCallbackType, tc.expSuccess, false, preRelaySenderBalance, fee)
})
}
}

Expand Down Expand Up @@ -174,16 +176,18 @@ func (s *CallbacksTestSuite) TestIncentivizedTransferTimeoutCallbacks() {
}

for _, tc := range testCases {
s.SetupFeeTransferTest()
s.Run(tc.name, func() {
s.SetupFeeTransferTest()

fee := feetypes.NewFee(defaultRecvFee, defaultAckFee, defaultTimeoutFee)
fee := feetypes.NewFee(defaultRecvFee, defaultAckFee, defaultTimeoutFee)

s.ExecutePayPacketFeeMsg(fee)
preRelaySenderBalance := sdk.NewCoins(s.chainA.GetSimApp().BankKeeper.GetBalance(s.chainA.GetContext(), s.chainA.SenderAccount.GetAddress(), ibctesting.TestCoin.Denom))
s.ExecuteTransferTimeout(tc.transferMemo, 1)
s.ExecutePayPacketFeeMsg(fee)
preRelaySenderBalance := sdk.NewCoins(s.chainA.GetSimApp().BankKeeper.GetBalance(s.chainA.GetContext(), s.chainA.SenderAccount.GetAddress(), ibctesting.TestCoin.Denom))
s.ExecuteTransferTimeout(tc.transferMemo, 1)

// after incentivizing the packets
s.AssertHasExecutedExpectedCallbackWithFee(tc.expCallbackType, tc.expSuccess, true, preRelaySenderBalance, fee)
// after incentivizing the packets
s.AssertHasExecutedExpectedCallbackWithFee(tc.expCallbackType, tc.expSuccess, true, preRelaySenderBalance, fee)
})
}
}

Expand Down
16 changes: 10 additions & 6 deletions modules/apps/callbacks/ica_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,12 @@ func (s *CallbacksTestSuite) TestICACallbacks() {
}

for _, tc := range testCases {
icaAddr := s.SetupICATest()
s.Run(tc.name, func() {
icaAddr := s.SetupICATest()

s.ExecuteICATx(icaAddr, tc.icaMemo, 1)
s.AssertHasExecutedExpectedCallback(tc.expCallbackType, tc.expSuccess)
s.ExecuteICATx(icaAddr, tc.icaMemo, 1)
s.AssertHasExecutedExpectedCallback(tc.expCallbackType, tc.expSuccess)
})
}
}

Expand Down Expand Up @@ -167,10 +169,12 @@ func (s *CallbacksTestSuite) TestICATimeoutCallbacks() {
}

for _, tc := range testCases {
icaAddr := s.SetupICATest()
s.Run(tc.name, func() {
icaAddr := s.SetupICATest()

s.ExecuteICATimeout(icaAddr, tc.icaMemo, 1)
s.AssertHasExecutedExpectedCallback(tc.expCallbackType, tc.expSuccess)
s.ExecuteICATimeout(icaAddr, tc.icaMemo, 1)
s.AssertHasExecutedExpectedCallback(tc.expCallbackType, tc.expSuccess)
})
}
}

Expand Down
41 changes: 21 additions & 20 deletions testing/mock/keeper.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
clienttypes "github.com/cosmos/ibc-go/v7/modules/core/02-client/types"
channeltypes "github.com/cosmos/ibc-go/v7/modules/core/04-channel/types"
ibcexported "github.com/cosmos/ibc-go/v7/modules/core/exported"
"github.com/cosmos/ibc-go/v7/testing/mock/types"
)

// MockKeeper implements callbacktypes.ContractKeeper
Expand All @@ -28,12 +27,16 @@ type Keeper struct {

// This is a mock keeper used for testing. It is not wired up to any modules.
// It implements the interface functions expected by the ibccallbacks middleware
// so that it can be tested with simapp.
// so that it can be tested with simapp. The keeper is responsible for tracking
// two metrics:
// - number of callbacks called per callback type
// - stateful entry attempts
//
// The counter for callbacks allows us to ensure the correct callbacks were routed to
// and the stateful entries allows us to track state reversals or reverted state upon
// contract execution failure or out of gas errors.
type ContractKeeper struct {
SendPacketCallbackCounter *types.CallbackCounter
AckCallbackCounter *types.CallbackCounter
TimeoutCallbackCounter *types.CallbackCounter
WriteAcknowledgementCallbackCounter *types.CallbackCounter
Counters map[callbacktypes.CallbackType]int
}

// SetStateCounter sets the stateful callback counter in state.
Expand Down Expand Up @@ -66,11 +69,7 @@ func NewMockKeeper(key storetypes.StoreKey) Keeper {
return Keeper{
key: key,
ContractKeeper: ContractKeeper{
SendPacketCallbackCounter: types.NewCallbackCounter(),
AckCallbackCounter: types.NewCallbackCounter(),
TimeoutCallbackCounter: types.NewCallbackCounter(),
WriteAcknowledgementCallbackCounter: types.NewCallbackCounter(),
},
Counters: make(map[callbacktypes.CallbackType]int)},
}
}

Expand All @@ -88,7 +87,7 @@ func (k Keeper) IBCSendPacketCallback(
contractAddress,
packetSenderAddress string,
) error {
return k.processMockCallback(ctx, callbacktypes.CallbackTypeSendPacket, k.SendPacketCallbackCounter, packetSenderAddress)
return k.processMockCallback(ctx, callbacktypes.CallbackTypeSendPacket, packetSenderAddress)
}

// IBCOnAcknowledgementPacketCallback returns nil if the gas meter has greater than
Expand All @@ -103,7 +102,7 @@ func (k Keeper) IBCOnAcknowledgementPacketCallback(
contractAddress,
packetSenderAddress string,
) error {
return k.processMockCallback(ctx, callbacktypes.CallbackTypeAcknowledgement, k.AckCallbackCounter, packetSenderAddress)
return k.processMockCallback(ctx, callbacktypes.CallbackTypeAcknowledgement, packetSenderAddress)
}

// IBCOnTimeoutPacketCallback returns nil if the gas meter has greater than
Expand All @@ -117,7 +116,7 @@ func (k Keeper) IBCOnTimeoutPacketCallback(
contractAddress,
packetSenderAddress string,
) error {
return k.processMockCallback(ctx, callbacktypes.CallbackTypeTimeoutPacket, k.TimeoutCallbackCounter, packetSenderAddress)
return k.processMockCallback(ctx, callbacktypes.CallbackTypeTimeoutPacket, packetSenderAddress)
}

// IBCWriteAcknowledgementCallback returns nil if the gas meter has greater than
Expand All @@ -130,7 +129,7 @@ func (k Keeper) IBCWriteAcknowledgementCallback(
ack ibcexported.Acknowledgement,
contractAddress string,
) error {
return k.processMockCallback(ctx, callbacktypes.CallbackTypeWriteAcknowledgement, k.WriteAcknowledgementCallbackCounter, "")
return k.processMockCallback(ctx, callbacktypes.CallbackTypeWriteAcknowledgement, "")
}

// processMockCallback returns nil if the gas meter has greater than or equal to 500000 gas remaining.
Expand All @@ -139,29 +138,31 @@ func (k Keeper) IBCWriteAcknowledgementCallback(
func (k Keeper) processMockCallback(
ctx sdk.Context,
callbackType callbacktypes.CallbackType,
callbackCounter *types.CallbackCounter,
authAddress string,
) error {
gasRemaining := ctx.GasMeter().GasRemaining()

// increment stateful entries, if the callbacks module handler
// reverts state, we can check by querying for the counter
// currently stored.
k.IncrementStatefulCounter(ctx)

// increment callback execution attempts
k.Counters[callbackType]++

if gasRemaining < 400000 {
callbackCounter.IncrementFailure()
// consume gas will panic since we attempt to consume 500_000 gas, for tests
ctx.GasMeter().ConsumeGas(500000, fmt.Sprintf("mock %s callback panic", callbackType))
} else if gasRemaining < 500000 {
callbackCounter.IncrementFailure()
ctx.GasMeter().ConsumeGas(gasRemaining, fmt.Sprintf("mock %s callback failure", callbackType))
return MockApplicationCallbackError
}

if authAddress == MockCallbackUnauthorizedAddress {
callbackCounter.IncrementFailure()
ctx.GasMeter().ConsumeGas(500000, fmt.Sprintf("mock %s callback unauthorized", callbackType))
return MockApplicationCallbackError
}

callbackCounter.IncrementSuccess()
ctx.GasMeter().ConsumeGas(500000, fmt.Sprintf("mock %s callback success", callbackType))
return nil
}
43 changes: 0 additions & 43 deletions testing/mock/types/callback_counter.go

This file was deleted.

0 comments on commit e0c2f33

Please sign in to comment.