diff --git a/modules/core/04-channel/keeper/export_test.go b/modules/core/04-channel/keeper/export_test.go index a89bf2a95c8..21fad92d9dd 100644 --- a/modules/core/04-channel/keeper/export_test.go +++ b/modules/core/04-channel/keeper/export_test.go @@ -29,3 +29,8 @@ func (k Keeper) CheckForUpgradeCompatibility(ctx sdk.Context, upgradeFields, cou func (k Keeper) SyncUpgradeSequence(ctx sdk.Context, portID, channelID string, channel types.Channel, counterpartyUpgradeSequence uint64) error { return k.syncUpgradeSequence(ctx, portID, channelID, channel, counterpartyUpgradeSequence) } + +// WriteErrorReceipt is a wrapper around writeErrorReceipt to allow the function to be directly called in tests. +func (k Keeper) WriteErrorReceipt(ctx sdk.Context, portID, channelID string, upgradeError *types.UpgradeError) { + k.writeErrorReceipt(ctx, portID, channelID, upgradeError) +} diff --git a/modules/core/04-channel/keeper/upgrade.go b/modules/core/04-channel/keeper/upgrade.go index 27007683c44..dc9ebfa9e42 100644 --- a/modules/core/04-channel/keeper/upgrade.go +++ b/modules/core/04-channel/keeper/upgrade.go @@ -941,3 +941,21 @@ func (k Keeper) restoreChannel(ctx sdk.Context, portID, channelID string, upgrad return channel } + +// writeErrorReceipt will write an error receipt from the provided UpgradeError. +func (k Keeper) writeErrorReceipt(ctx sdk.Context, portID, channelID string, upgradeError *types.UpgradeError) { + channel, found := k.GetChannel(ctx, portID, channelID) + if !found { + panic(errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID)) + } + + errorReceiptToWrite := upgradeError.GetErrorReceipt() + + existingErrorReceipt, found := k.GetUpgradeErrorReceipt(ctx, portID, channelID) + if found && existingErrorReceipt.Sequence >= errorReceiptToWrite.Sequence { + panic(errorsmod.Wrapf(types.ErrInvalidUpgradeSequence, "error receipt sequence (%d) must be greater than existing error receipt sequence (%d)", errorReceiptToWrite.Sequence, existingErrorReceipt.Sequence)) + } + + k.SetUpgradeErrorReceipt(ctx, portID, channelID, errorReceiptToWrite) + EmitErrorReceiptEvent(ctx, portID, channelID, channel, upgradeError) +} diff --git a/modules/core/04-channel/keeper/upgrade_test.go b/modules/core/04-channel/keeper/upgrade_test.go index a0ae26a0933..5f9b7c859ba 100644 --- a/modules/core/04-channel/keeper/upgrade_test.go +++ b/modules/core/04-channel/keeper/upgrade_test.go @@ -2379,3 +2379,71 @@ func (suite *KeeperTestSuite) TestChanUpgradeCrossingHelloWithHistoricalProofs() }) } } + +func (suite *KeeperTestSuite) TestWriteErrorReceipt() { + var path *ibctesting.Path + var upgradeError *types.UpgradeError + + testCases := []struct { + name string + malleate func() + expError error + }{ + { + "success", + func() {}, + nil, + }, + { + "success: existing error receipt found at a lower sequence", + func() { + // write an error sequence with a lower sequence number + previousUpgradeError := types.NewUpgradeError(upgradeError.GetErrorReceipt().Sequence-1, types.ErrInvalidUpgrade) + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, previousUpgradeError.GetErrorReceipt()) + }, + nil, + }, + { + "failure: existing error receipt found at a higher sequence", + func() { + // write an error sequence with a higher sequence number + previousUpgradeError := types.NewUpgradeError(upgradeError.GetErrorReceipt().Sequence+1, types.ErrInvalidUpgrade) + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, previousUpgradeError.GetErrorReceipt()) + }, + errorsmod.Wrap(types.ErrInvalidUpgradeSequence, "error receipt sequence (10) must be greater than existing error receipt sequence (11)"), + }, + { + "failure: channel not found", + func() { + suite.chainA.DeleteKey(host.ChannelKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) + }, + errorsmod.Wrap(types.ErrChannelNotFound, "port ID (mock) channel ID (channel-0)"), + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + path = ibctesting.NewPath(suite.chainA, suite.chainB) + suite.coordinator.Setup(path) + + channelKeeper := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper + + upgradeError = types.NewUpgradeError(10, types.ErrInvalidUpgrade) + + tc.malleate() + + expPass := tc.expError == nil + if expPass { + suite.NotPanics(func() { + channelKeeper.WriteErrorReceipt(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgradeError) + }) + } else { + suite.PanicsWithError(tc.expError.Error(), func() { + channelKeeper.WriteErrorReceipt(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgradeError) + }) + } + }) + } +}