Skip to content

Commit

Permalink
Make sure that the error receipt we write always has a sequence great…
Browse files Browse the repository at this point in the history
…er than the existing one. (#5237)
  • Loading branch information
chatton authored Dec 13, 2023
1 parent a99c84c commit 0323942
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 0 deletions.
5 changes: 5 additions & 0 deletions modules/core/04-channel/keeper/export_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,8 @@ func (k Keeper) CheckForUpgradeCompatibility(ctx sdk.Context, upgradeFields, cou
func (k Keeper) SyncUpgradeSequence(ctx sdk.Context, portID, channelID string, channel types.Channel, counterpartyUpgradeSequence uint64) error {
return k.syncUpgradeSequence(ctx, portID, channelID, channel, counterpartyUpgradeSequence)
}

// WriteErrorReceipt is a wrapper around writeErrorReceipt to allow the function to be directly called in tests.
func (k Keeper) WriteErrorReceipt(ctx sdk.Context, portID, channelID string, upgradeError *types.UpgradeError) {
k.writeErrorReceipt(ctx, portID, channelID, upgradeError)
}
18 changes: 18 additions & 0 deletions modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -941,3 +941,21 @@ func (k Keeper) restoreChannel(ctx sdk.Context, portID, channelID string, upgrad

return channel
}

// writeErrorReceipt will write an error receipt from the provided UpgradeError.
func (k Keeper) writeErrorReceipt(ctx sdk.Context, portID, channelID string, upgradeError *types.UpgradeError) {
channel, found := k.GetChannel(ctx, portID, channelID)
if !found {
panic(errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID))
}

errorReceiptToWrite := upgradeError.GetErrorReceipt()

existingErrorReceipt, found := k.GetUpgradeErrorReceipt(ctx, portID, channelID)
if found && existingErrorReceipt.Sequence >= errorReceiptToWrite.Sequence {
panic(errorsmod.Wrapf(types.ErrInvalidUpgradeSequence, "error receipt sequence (%d) must be greater than existing error receipt sequence (%d)", errorReceiptToWrite.Sequence, existingErrorReceipt.Sequence))
}

k.SetUpgradeErrorReceipt(ctx, portID, channelID, errorReceiptToWrite)
EmitErrorReceiptEvent(ctx, portID, channelID, channel, upgradeError)
}
68 changes: 68 additions & 0 deletions modules/core/04-channel/keeper/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2379,3 +2379,71 @@ func (suite *KeeperTestSuite) TestChanUpgradeCrossingHelloWithHistoricalProofs()
})
}
}

func (suite *KeeperTestSuite) TestWriteErrorReceipt() {
var path *ibctesting.Path
var upgradeError *types.UpgradeError

testCases := []struct {
name string
malleate func()
expError error
}{
{
"success",
func() {},
nil,
},
{
"success: existing error receipt found at a lower sequence",
func() {
// write an error sequence with a lower sequence number
previousUpgradeError := types.NewUpgradeError(upgradeError.GetErrorReceipt().Sequence-1, types.ErrInvalidUpgrade)
suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, previousUpgradeError.GetErrorReceipt())
},
nil,
},
{
"failure: existing error receipt found at a higher sequence",
func() {
// write an error sequence with a higher sequence number
previousUpgradeError := types.NewUpgradeError(upgradeError.GetErrorReceipt().Sequence+1, types.ErrInvalidUpgrade)
suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, previousUpgradeError.GetErrorReceipt())
},
errorsmod.Wrap(types.ErrInvalidUpgradeSequence, "error receipt sequence (10) must be greater than existing error receipt sequence (11)"),
},
{
"failure: channel not found",
func() {
suite.chainA.DeleteKey(host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID))
},
errorsmod.Wrap(types.ErrChannelNotFound, "port ID (mock) channel ID (channel-0)"),
},
}

for _, tc := range testCases {
tc := tc
suite.Run(tc.name, func() {
suite.SetupTest()
path = ibctesting.NewPath(suite.chainA, suite.chainB)
suite.coordinator.Setup(path)

channelKeeper := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper

upgradeError = types.NewUpgradeError(10, types.ErrInvalidUpgrade)

tc.malleate()

expPass := tc.expError == nil
if expPass {
suite.NotPanics(func() {
channelKeeper.WriteErrorReceipt(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgradeError)
})
} else {
suite.PanicsWithError(tc.expError.Error(), func() {
channelKeeper.WriteErrorReceipt(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgradeError)
})
}
})
}
}

0 comments on commit 0323942

Please sign in to comment.