diff --git a/modules/core/04-channel/keeper/events.go b/modules/core/04-channel/keeper/events.go index c6b98f456b6..f91a0c75b21 100644 --- a/modules/core/04-channel/keeper/events.go +++ b/modules/core/04-channel/keeper/events.go @@ -405,7 +405,7 @@ func emitChannelUpgradeTimeoutEvent(ctx sdk.Context, portID string, channelID st } // emitErrorReceiptEvent emits an error receipt event -func emitErrorReceiptEvent(ctx sdk.Context, portID string, channelID string, currentChannel types.Channel, upgradeFields types.UpgradeFields, err error) { +func emitErrorReceiptEvent(ctx sdk.Context, portID string, channelID string, currentChannel types.Channel, err error) { ctx.EventManager().EmitEvents(sdk.Events{ sdk.NewEvent( types.EventTypeChannelUpgradeInit, // TODO(bug): use correct const value @@ -413,9 +413,6 @@ func emitErrorReceiptEvent(ctx sdk.Context, portID string, channelID string, cur sdk.NewAttribute(types.AttributeKeyChannelID, channelID), sdk.NewAttribute(types.AttributeCounterpartyPortID, currentChannel.Counterparty.PortId), sdk.NewAttribute(types.AttributeCounterpartyChannelID, currentChannel.Counterparty.ChannelId), - sdk.NewAttribute(types.AttributeKeyUpgradeConnectionHops, upgradeFields.ConnectionHops[0]), - sdk.NewAttribute(types.AttributeKeyUpgradeVersion, upgradeFields.Version), - sdk.NewAttribute(types.AttributeKeyUpgradeOrdering, upgradeFields.Ordering.String()), sdk.NewAttribute(types.AttributeKeyUpgradeSequence, fmt.Sprintf("%d", currentChannel.UpgradeSequence)), sdk.NewAttribute(types.AttributeKeyUpgradeErrorReceipt, err.Error()), ), diff --git a/modules/core/04-channel/keeper/upgrade.go b/modules/core/04-channel/keeper/upgrade.go index 921363efef2..8cfea3f24ff 100644 --- a/modules/core/04-channel/keeper/upgrade.go +++ b/modules/core/04-channel/keeper/upgrade.go @@ -621,8 +621,7 @@ func (k Keeper) WriteUpgradeCancelChannel(ctx sdk.Context, portID, channelID str previousState := channel.State - k.SetUpgradeErrorReceipt(ctx, portID, channelID, errorReceipt) - channel = k.restoreChannel(ctx, portID, channelID, errorReceipt.Sequence, channel) + channel = k.restoreChannel(ctx, portID, channelID, errorReceipt.Sequence, channel, types.NewUpgradeError(errorReceipt.Sequence, types.ErrInvalidUpgrade)) k.Logger(ctx).Info("channel state updated", "port-id", portID, "channel-id", channelID, "previous-state", previousState, "new-state", types.OPEN.String()) emitChannelUpgradeCancelEvent(ctx, portID, channelID, channel, upgrade) @@ -634,9 +633,7 @@ func (k Keeper) ChanUpgradeTimeout( ctx sdk.Context, portID, channelID string, counterpartyChannel types.Channel, - prevErrorReceipt *types.ErrorReceipt, - proofCounterpartyChannel, - proofErrorReceipt []byte, + proofCounterpartyChannel []byte, proofHeight exported.Height, ) error { channel, found := k.GetChannel(ctx, portID, channelID) @@ -644,8 +641,8 @@ func (k Keeper) ChanUpgradeTimeout( return errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) } - if channel.State != types.INITUPGRADE { - return errorsmod.Wrapf(types.ErrInvalidChannelState, "channel state is not INITUPGRADE (got %s)", channel.State) + if !collections.Contains(channel.State, []types.State{types.STATE_FLUSHING, types.STATE_FLUSHCOMPLETE}) { + return errorsmod.Wrapf(types.ErrInvalidChannelState, "expected one of [%s, %s], got %s", types.STATE_FLUSHING, types.STATE_FLUSHCOMPLETE, channel.State) } upgrade, found := k.GetUpgrade(ctx, portID, channelID) @@ -668,23 +665,44 @@ func (k Keeper) ChanUpgradeTimeout( ) } - // proof must be from a height after timeout has elapsed. Either timeoutHeight or timeoutTimestamp must be defined. - // if timeoutHeight is defined and proof is from before timeout height, abort transaction proofTimestamp, err := k.connectionKeeper.GetTimestampAtHeight(ctx, connection, proofHeight) if err != nil { return err } - timeout := upgrade.Timeout - proofHeightIsInvalid := timeout.Height.IsZero() || proofHeight.LT(timeout.Height) - proofTimestampIsInvalid := timeout.Timestamp == 0 || proofTimestamp < timeout.Timestamp - if proofHeightIsInvalid && proofTimestampIsInvalid { - return errorsmod.Wrap(types.ErrInvalidUpgradeTimeout, "timeout has not yet passed on counterparty chain") + // proof must be from a height after timeout has elapsed. Either timeoutHeight or timeoutTimestamp must be defined. + // if timeoutHeight is defined and proof is from before timeout height, abort transaction + timeoutHeight := upgrade.Timeout.Height + timeoutTimeStamp := upgrade.Timeout.Timestamp + if (timeoutHeight.IsZero() || proofHeight.LT(timeoutHeight)) && + (timeoutTimeStamp == 0 || proofTimestamp < timeoutTimeStamp) { + return errorsmod.Wrap(types.ErrInvalidUpgradeTimeout, "upgrade timeout has not been reached for height or timestamp") + } + + // counterparty channel must be proved to still be in OPEN state or FLUSHING state. + if !collections.Contains(counterpartyChannel.State, []types.State{types.OPEN, types.STATE_FLUSHING}) { + return errorsmod.Wrapf(types.ErrInvalidCounterparty, "expected one of [%s, %s], got %s", types.OPEN, types.STATE_FLUSHING, counterpartyChannel.State) + } + + if counterpartyChannel.State == types.OPEN { + upgradeConnection, found := k.connectionKeeper.GetConnection(ctx, upgrade.Fields.ConnectionHops[0]) + if !found { + return errorsmod.Wrap( + connectiontypes.ErrConnectionNotFound, + upgrade.Fields.ConnectionHops[0], + ) + } + counterpartyHops := []string{upgradeConnection.GetCounterparty().GetConnectionID()} + + upgradeAlreadyComplete := upgrade.Fields.Version == counterpartyChannel.Version && upgrade.Fields.Ordering == counterpartyChannel.Ordering && upgrade.Fields.ConnectionHops[0] == counterpartyHops[0] + if upgradeAlreadyComplete { + // counterparty has already successfully upgraded so we cannot timeout + return errorsmod.Wrap(types.ErrUpgradeTimeoutFailed, "counterparty channel is already upgraded") + } } - // counterparty channel must be proved to still be in OPEN state or INITUPGRADE state (crossing hellos) - if !collections.Contains(counterpartyChannel.State, []types.State{types.OPEN, types.INITUPGRADE}) { - return errorsmod.Wrapf(types.ErrInvalidChannelState, "expected one of [%s, %s], got %s", types.OPEN, types.INITUPGRADE, counterpartyChannel.State) + if counterpartyChannel.UpgradeSequence < channel.UpgradeSequence { + return errorsmod.Wrapf(types.ErrInvalidUpgradeSequence, "counterparty channel upgrade sequence (%d) must be greater than or equal to current upgrade sequence (%d)", counterpartyChannel.UpgradeSequence, channel.UpgradeSequence) } // verify the counterparty channel state @@ -699,38 +717,6 @@ func (k Keeper) ChanUpgradeTimeout( return errorsmod.Wrap(err, "failed to verify counterparty channel state") } - // Error receipt passed in is either nil or it is a stale error receipt from a previous upgrade - if prevErrorReceipt == nil { - if err := k.connectionKeeper.VerifyChannelUpgradeErrorAbsence( - ctx, - channel.Counterparty.PortId, channel.Counterparty.ChannelId, - connection, - proofErrorReceipt, - proofHeight, - ); err != nil { - return errorsmod.Wrap(err, "failed to verify absence of counterparty channel upgrade error receipt") - } - - return nil - } - // timeout for this sequence can only succeed if the error receipt written into the error path on the counterparty - // was for a previous sequence by the timeout deadline. - upgradeSequence := channel.UpgradeSequence - if upgradeSequence <= prevErrorReceipt.Sequence { - return errorsmod.Wrapf(types.ErrInvalidUpgradeSequence, "previous counterparty error receipt sequence is greater than or equal to our current upgrade sequence: %d > %d", prevErrorReceipt.Sequence, upgradeSequence) - } - - if err := k.connectionKeeper.VerifyChannelUpgradeError( - ctx, - channel.Counterparty.PortId, channel.Counterparty.ChannelId, - connection, - *prevErrorReceipt, - proofErrorReceipt, - proofHeight, - ); err != nil { - return errorsmod.Wrap(err, "failed to verify counterparty channel upgrade error receipt") - } - return nil } @@ -753,7 +739,7 @@ func (k Keeper) WriteUpgradeTimeoutChannel( panic(fmt.Sprintf("could not find existing upgrade when cancelling channel upgrade, channelID: %s, portID: %s", channelID, portID)) } - channel = k.restoreChannel(ctx, portID, channelID, channel.UpgradeSequence, channel) + channel = k.restoreChannel(ctx, portID, channelID, channel.UpgradeSequence, channel, types.NewUpgradeError(channel.UpgradeSequence, types.ErrUpgradeTimeout)) k.Logger(ctx).Info("channel state restored", "port-id", portID, "channel-id", channelID) emitChannelUpgradeTimeoutEvent(ctx, portID, channelID, channel, upgrade) @@ -924,20 +910,11 @@ func (k Keeper) abortUpgrade(ctx sdk.Context, portID, channelID string, err erro return errorsmod.Wrap(types.ErrInvalidUpgradeError, "cannot abort upgrade handshake with nil error") } - upgrade, found := k.GetUpgrade(ctx, portID, channelID) - if !found { - return errorsmod.Wrapf(types.ErrUpgradeNotFound, "port ID (%s) channel ID (%s)", portID, channelID) - } - channel, found := k.GetChannel(ctx, portID, channelID) if !found { return errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) } - // the channel upgrade sequence has already been updated in ChannelUpgradeTry, so we can pass - // its updated value. - k.restoreChannel(ctx, portID, channelID, channel.UpgradeSequence, channel) - // in the case of application callbacks, the error may not be an upgrade error. // in this case we need to construct one in order to write the error receipt. upgradeError, ok := err.(*types.UpgradeError) @@ -945,15 +922,14 @@ func (k Keeper) abortUpgrade(ctx sdk.Context, portID, channelID string, err erro upgradeError = types.NewUpgradeError(channel.UpgradeSequence, err) } - if err := k.WriteErrorReceipt(ctx, portID, channelID, upgrade.Fields, upgradeError); err != nil { - return err - } - + // the channel upgrade sequence has already been updated in ChannelUpgradeTry, so we can pass + // its updated value. + k.restoreChannel(ctx, portID, channelID, channel.UpgradeSequence, channel, upgradeError) return nil } // restoreChannel will restore the channel state and flush status to their pre-upgrade state so that upgrade is aborted. -func (k Keeper) restoreChannel(ctx sdk.Context, portID, channelID string, upgradeSequence uint64, channel types.Channel) types.Channel { +func (k Keeper) restoreChannel(ctx sdk.Context, portID, channelID string, upgradeSequence uint64, channel types.Channel, err *types.UpgradeError) types.Channel { channel.State = types.OPEN channel.UpgradeSequence = upgradeSequence @@ -961,17 +937,20 @@ func (k Keeper) restoreChannel(ctx sdk.Context, portID, channelID string, upgrad // delete state associated with upgrade which is no longer required. k.deleteUpgradeInfo(ctx, portID, channelID) + + _ = k.WriteErrorReceipt(ctx, portID, channelID, err) + return channel } // WriteErrorReceipt will write an error receipt from the provided UpgradeError. -func (k Keeper) WriteErrorReceipt(ctx sdk.Context, portID, channelID string, upgradeFields types.UpgradeFields, upgradeError *types.UpgradeError) error { +func (k Keeper) WriteErrorReceipt(ctx sdk.Context, portID, channelID string, upgradeError *types.UpgradeError) error { channel, found := k.GetChannel(ctx, portID, channelID) if !found { return errorsmod.Wrapf(types.ErrChannelNotFound, "port ID (%s) channel ID (%s)", portID, channelID) } k.SetUpgradeErrorReceipt(ctx, portID, channelID, upgradeError.GetErrorReceipt()) - emitErrorReceiptEvent(ctx, portID, channelID, channel, upgradeFields, upgradeError) + emitErrorReceiptEvent(ctx, portID, channelID, channel, upgradeError) return nil } diff --git a/modules/core/04-channel/keeper/upgrade_test.go b/modules/core/04-channel/keeper/upgrade_test.go index 184805a4bbf..426767b8a77 100644 --- a/modules/core/04-channel/keeper/upgrade_test.go +++ b/modules/core/04-channel/keeper/upgrade_test.go @@ -2,6 +2,7 @@ package keeper_test import ( "fmt" + "math" errorsmod "cosmossdk.io/errors" @@ -1406,206 +1407,242 @@ func (suite *KeeperTestSuite) TestWriteUpgradeCancelChannel() { } } -// func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { -// var ( -// path *ibctesting.Path -// errReceipt *types.ErrorReceipt -// proofHeight exported.Height -// proofCounterpartyChannel []byte -// proofErrorReceipt []byte -// ) +func (suite *KeeperTestSuite) TestChanUpgradeTimeout() { + var ( + path *ibctesting.Path + proofHeight exported.Height + proofCounterpartyChannel []byte + ) -// testCases := []struct { -// name string -// malleate func() -// expError error -// }{ -// // { -// // "success: proof height has passed", -// // func() {}, -// // nil, -// // }, -// { -// "success: proof timestamp has passed", -// func() { -// upgrade := path.EndpointA.GetProposedUpgrade() -// upgrade.Timeout.Height = defaultTimeoutHeight -// upgrade.Timeout.Timestamp = 5 -// suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgrade) + testCases := []struct { + name string + malleate func() + expError error + }{ + { + "success: proof height has passed", + func() { + // force timeout as the default timeout is 1000. + // TODO: modify timeout in test to be lower. + suite.coordinator.CommitNBlocks(suite.chainB, 1000) -// suite.Require().NoError(path.EndpointA.UpdateClient()) + // ensure clients are up to date to receive valid proofs + suite.Require().NoError(path.EndpointA.UpdateClient()) -// proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() -// upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) -// proofErrorReceipt, _ = suite.chainB.QueryProof(upgradeErrorReceiptKey) -// }, -// nil, -// }, -// { -// "success: non-nil error receipt", -// func() { -// errReceipt = &types.ErrorReceipt{ -// Sequence: 0, -// Message: types.ErrInvalidUpgrade.Error(), -// } + proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + }, + nil, + }, + { + "success: proof timestamp has passed", + func() { + upgrade := path.EndpointA.GetProposedUpgrade() + upgrade.Timeout.Height = clienttypes.ZeroHeight() + upgrade.Timeout.Timestamp = 1 + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgrade) -// suite.chainB.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgradeErrorReceipt(suite.chainB.GetContext(), path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID, *errReceipt) + suite.Require().NoError(path.EndpointA.UpdateClient()) -// suite.Require().NoError(path.EndpointB.UpdateClient()) -// suite.Require().NoError(path.EndpointA.UpdateClient()) + proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + }, + nil, + }, + { + "channel not found", + func() { + path.EndpointA.ChannelID = ibctesting.InvalidID + }, + types.ErrChannelNotFound, + }, + { + "channel state is not in FLUSHING or FLUSHINGCOMPLETE state", + func() { + suite.Require().NoError(path.EndpointA.SetChannelState(types.OPEN)) + }, + types.ErrInvalidChannelState, + }, + { + "current upgrade not found", + func() { + suite.chainA.DeleteKey(host.ChannelUpgradeKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) + }, + types.ErrUpgradeNotFound, + }, + { + "connection not found", + func() { + channel := path.EndpointA.GetChannel() + channel.ConnectionHops[0] = ibctesting.InvalidID + path.EndpointA.SetChannel(channel) + }, + connectiontypes.ErrConnectionNotFound, + }, + { + "connection not open", + func() { + connectionEnd := path.EndpointA.GetConnection() + connectionEnd.State = connectiontypes.UNINITIALIZED + path.EndpointA.SetConnection(connectionEnd) + }, + connectiontypes.ErrInvalidConnectionState, + }, + { + "unable to retrieve timestamp at proof height", + func() { + // TODO: revert this when the upgrade timeout is not hard coded to 1000 + proofHeight = clienttypes.NewHeight(clienttypes.ParseChainID(suite.chainA.ChainID), uint64(suite.chainA.GetContext().BlockHeight())+1000) + }, + clienttypes.ErrConsensusStateNotFound, + }, + { + "invalid channel state proof", + func() { + channel := path.EndpointB.GetChannel() + channel.State = types.OPEN + path.EndpointB.SetChannel(channel) -// proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() -// upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) -// proofErrorReceipt, _ = suite.chainB.QueryProof(upgradeErrorReceiptKey) -// }, -// nil, -// }, -// { -// "channel not found", -// func() { -// path.EndpointA.ChannelID = ibctesting.InvalidID -// }, -// types.ErrChannelNotFound, -// }, -// { -// "channel state is not in INITUPGRADE state", -// func() { -// suite.Require().NoError(path.EndpointA.SetChannelState(types.ACKUPGRADE)) -// }, -// types.ErrInvalidChannelState, -// }, -// { -// "current upgrade not found", -// func() { -// suite.chainA.DeleteKey(host.ChannelUpgradeKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) -// }, -// types.ErrUpgradeNotFound, -// }, -// { -// "connection not found", -// func() { -// channel := path.EndpointA.GetChannel() -// channel.ConnectionHops[0] = ibctesting.InvalidID -// path.EndpointA.SetChannel(channel) -// }, -// connectiontypes.ErrConnectionNotFound, -// }, -// { -// "connection not open", -// func() { -// connectionEnd := path.EndpointA.GetConnection() -// connectionEnd.State = connectiontypes.UNINITIALIZED -// path.EndpointA.SetConnection(connectionEnd) -// }, -// connectiontypes.ErrInvalidConnectionState, -// }, -// { -// "unable to retrieve timestamp at proof height", -// func() { -// proofHeight = suite.chainA.GetTimeoutHeight() -// }, -// clienttypes.ErrConsensusStateNotFound, -// }, -// { -// "timeout has not passed", -// func() { -// upgrade := path.EndpointA.GetProposedUpgrade() -// upgrade.Timeout.Height = suite.chainA.GetTimeoutHeight() -// suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgrade) + suite.coordinator.CommitNBlocks(suite.chainB, 1000) -// suite.Require().NoError(path.EndpointA.UpdateClient()) + suite.Require().NoError(path.EndpointA.UpdateClient()) -// proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() -// upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) -// proofErrorReceipt, _ = suite.chainB.QueryProof(upgradeErrorReceiptKey) -// }, -// types.ErrInvalidUpgradeTimeout, -// }, -// { -// "counterparty channel state is not OPEN or INITUPGRADE (crossing hellos)", -// func() { -// channel := path.EndpointB.GetChannel() -// channel.State = types.TRYUPGRADE -// path.EndpointB.SetChannel(channel) + proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() -// suite.Require().NoError(path.EndpointB.UpdateClient()) -// suite.Require().NoError(path.EndpointA.UpdateClient()) + // modify state so the proof becomes invalid. + channel.State = types.STATE_FLUSHING + path.EndpointB.SetChannel(channel) + suite.coordinator.CommitNBlocks(suite.chainB, 1) + }, + commitmenttypes.ErrInvalidProof, + }, + { + "invalid counterparty upgrade sequence", + func() { + channel := path.EndpointB.GetChannel() + channel.UpgradeSequence = path.EndpointA.GetChannel().UpgradeSequence - 1 + path.EndpointB.SetChannel(channel) -// proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() -// upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) -// proofErrorReceipt, _ = suite.chainB.QueryProof(upgradeErrorReceiptKey) -// }, -// types.ErrInvalidChannelState, -// }, -// { -// "non-nil error receipt: error receipt seq greater than current upgrade seq", -// func() { -// errReceipt = &types.ErrorReceipt{ -// Sequence: 3, -// Message: types.ErrInvalidUpgrade.Error(), -// } -// }, -// types.ErrInvalidUpgradeSequence, -// }, -// { -// "non-nil error receipt: error receipt seq equal to current upgrade seq", -// func() { -// errReceipt = &types.ErrorReceipt{ -// Sequence: 1, -// Message: types.ErrInvalidUpgrade.Error(), -// } -// }, -// types.ErrInvalidUpgradeSequence, -// }, -// } + suite.coordinator.CommitNBlocks(suite.chainB, 1000) -// for _, tc := range testCases { -// tc := tc -// suite.Run(tc.name, func() { -// suite.SetupTest() -// expPass := tc.expError == nil + suite.Require().NoError(path.EndpointA.UpdateClient()) -// path = ibctesting.NewPath(suite.chainA, suite.chainB) -// suite.coordinator.Setup(path) + proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + }, + types.ErrInvalidUpgradeSequence, + }, + { + "timeout height has not passed", + func() { + upgrade := path.EndpointA.GetProposedUpgrade() + upgrade.Timeout.Height = suite.chainA.GetTimeoutHeight() + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgrade) -// path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion -// path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + suite.Require().NoError(path.EndpointA.UpdateClient()) -// errReceipt = nil + proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + }, + types.ErrInvalidUpgradeTimeout, + }, + { + "timeout timestamp has not passed", + func() { + upgrade := path.EndpointA.GetProposedUpgrade() + upgrade.Timeout.Timestamp = math.MaxUint64 + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgrade) -// // set timeout height to 1 to ensure timeout -// path.EndpointA.ChannelConfig.ProposedUpgrade.Timeout.Height = clienttypes.NewHeight(1, 1) -// suite.Require().NoError(path.EndpointA.ChanUpgradeInit()) + suite.Require().NoError(path.EndpointB.UpdateClient()) -// // ensure clients are up to date to receive valid proofs -// suite.Require().NoError(path.EndpointB.UpdateClient()) -// suite.Require().NoError(path.EndpointA.UpdateClient()) + proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + }, + types.ErrInvalidUpgradeTimeout, + }, + { + "counterparty channel state is not OPEN or FLUSHING (crossing hellos)", + func() { + channel := path.EndpointB.GetChannel() + channel.State = types.STATE_FLUSHCOMPLETE + path.EndpointB.SetChannel(channel) -// proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() -// upgradeErrorReceiptKey := host.ChannelUpgradeErrorKey(path.EndpointB.ChannelConfig.PortID, path.EndpointB.ChannelID) -// proofErrorReceipt, _ = suite.chainB.QueryProof(upgradeErrorReceiptKey) + suite.coordinator.CommitNBlocks(suite.chainB, 1000) -// tc.malleate() + suite.Require().NoError(path.EndpointA.UpdateClient()) -// err := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeTimeout( -// suite.chainA.GetContext(), -// path.EndpointA.ChannelConfig.PortID, -// path.EndpointA.ChannelID, -// path.EndpointB.GetChannel(), -// errReceipt, -// proofCounterpartyChannel, -// proofErrorReceipt, -// proofHeight, -// ) + proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + }, + types.ErrInvalidCounterparty, + }, + { + "counterparty proposed connection invalid", + func() { + channel := path.EndpointB.GetChannel() + channel.State = types.OPEN + path.EndpointB.SetChannel(channel) -// if expPass { -// suite.Require().NoError(err) -// } else { -// suite.assertUpgradeError(err, tc.expError) -// } -// }) -// } -// } + upgrade := path.EndpointA.GetProposedUpgrade() + upgrade.Fields.ConnectionHops = []string{"connection-100"} + suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.SetUpgrade(suite.chainA.GetContext(), path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID, upgrade) + + suite.coordinator.CommitNBlocks(suite.chainB, 1000) + + suite.Require().NoError(path.EndpointA.UpdateClient()) + suite.Require().NoError(path.EndpointB.UpdateClient()) + + proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + }, + connectiontypes.ErrConnectionNotFound, + }, + { + "counterparty channel already upgraded", + func() { + // put chainA channel into OPEN state since both sides are in FLUSHCOMPLETE + suite.Require().NoError(path.EndpointB.ChanUpgradeConfirm()) + + suite.coordinator.CommitNBlocks(suite.chainB, 1000) + + suite.Require().NoError(path.EndpointA.UpdateClient()) + + proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + }, + types.ErrUpgradeTimeoutFailed, + }, + } + + for _, tc := range testCases { + tc := tc + suite.Run(tc.name, func() { + suite.SetupTest() + expPass := tc.expError == nil + + path = ibctesting.NewPath(suite.chainA, suite.chainB) + suite.coordinator.Setup(path) + + path.EndpointA.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + path.EndpointB.ChannelConfig.ProposedUpgrade.Fields.Version = mock.UpgradeVersion + + suite.Require().NoError(path.EndpointA.ChanUpgradeInit()) + suite.Require().NoError(path.EndpointB.ChanUpgradeTry()) + suite.Require().NoError(path.EndpointA.ChanUpgradeAck()) + + proofCounterpartyChannel, _, proofHeight = path.EndpointA.QueryChannelUpgradeProof() + + tc.malleate() + + err := suite.chainA.GetSimApp().IBCKeeper.ChannelKeeper.ChanUpgradeTimeout( + suite.chainA.GetContext(), + path.EndpointA.ChannelConfig.PortID, + path.EndpointA.ChannelID, + path.EndpointB.GetChannel(), + proofCounterpartyChannel, + proofHeight, + ) + + if expPass { + suite.Require().NoError(err) + } else { + suite.assertUpgradeError(err, tc.expError) + } + }) + } +} func (suite *KeeperTestSuite) TestStartFlush() { var path *ibctesting.Path @@ -1809,10 +1846,10 @@ func (suite *KeeperTestSuite) assertUpgradeError(actualError, expError error) { suite.Require().True(errorsmod.IsOf(actualError, expError), fmt.Sprintf("expected error: %s, actual error: %s", expError, actualError)) } -// TestAbortHandshake tests that when the channel handshake is aborted, the channel state +// TestAbortUpgrade tests that when the channel handshake is aborted, the channel state // is restored the previous state and that an error receipt is written, and upgrade state which // is no longer required is deleted. -func (suite *KeeperTestSuite) TestAbortHandshake() { +func (suite *KeeperTestSuite) TestAbortUpgrade() { var ( path *ibctesting.Path upgradeError error @@ -1837,13 +1874,6 @@ func (suite *KeeperTestSuite) TestAbortHandshake() { }, expPass: true, }, - { - name: "upgrade does not exist", - malleate: func() { - suite.chainA.DeleteKey(host.ChannelUpgradeKey(path.EndpointA.ChannelConfig.PortID, path.EndpointA.ChannelID)) - }, - expPass: false, - }, { name: "channel does not exist", malleate: func() { diff --git a/modules/core/04-channel/types/errors.go b/modules/core/04-channel/types/errors.go index e1efa4e369e..a67d0bb9372 100644 --- a/modules/core/04-channel/types/errors.go +++ b/modules/core/04-channel/types/errors.go @@ -51,4 +51,5 @@ var ( ErrUpgradeTimeout = errorsmod.Register(SubModuleName, 35, "upgrade timed-out") ErrInvalidUpgradeTimeout = errorsmod.Register(SubModuleName, 36, "upgrade timeout is invalid") ErrPendingInflightPackets = errorsmod.Register(SubModuleName, 37, "pending inflight packets exist") + ErrUpgradeTimeoutFailed = errorsmod.Register(SubModuleName, 38, "upgrade timeout failed") ) diff --git a/modules/core/04-channel/types/msgs.go b/modules/core/04-channel/types/msgs.go index 3d260431461..40dcb45a511 100644 --- a/modules/core/04-channel/types/msgs.go +++ b/modules/core/04-channel/types/msgs.go @@ -723,8 +723,8 @@ func (msg MsgChannelUpgradeOpen) ValidateBasic() error { return errorsmod.Wrap(commitmenttypes.ErrInvalidProof, "cannot submit an empty channel proof") } - if !collections.Contains(msg.CounterpartyChannelState, []State{TRYUPGRADE, ACKUPGRADE, OPEN}) { - return errorsmod.Wrapf(ErrInvalidChannelState, "expected channel state to be one of: %s, %s or %s, got: %s", TRYUPGRADE, ACKUPGRADE, OPEN, msg.CounterpartyChannelState) + if !collections.Contains(msg.CounterpartyChannelState, []State{STATE_FLUSHCOMPLETE, OPEN}) { + return errorsmod.Wrapf(ErrInvalidChannelState, "expected channel state to be one of: [%s, %s], got: %s", STATE_FLUSHCOMPLETE, OPEN, msg.CounterpartyChannelState) } _, err := sdk.AccAddressFromBech32(msg.Signer) @@ -782,8 +782,8 @@ func (msg MsgChannelUpgradeTimeout) ValidateBasic() error { return errorsmod.Wrap(commitmenttypes.ErrInvalidProof, "cannot submit an empty proof") } - if msg.CounterpartyChannel.State != OPEN { - return errorsmod.Wrapf(ErrInvalidChannelState, "expected: %s, got: %s", OPEN, msg.CounterpartyChannel.State) + if !collections.Contains(msg.CounterpartyChannel.State, []State{STATE_FLUSHING, OPEN}) { + return errorsmod.Wrapf(ErrInvalidChannelState, "expected counterparty channel state to be one of: [%s, %s], got: %s", STATE_FLUSHING, OPEN, msg.CounterpartyChannel.State) } _, err := sdk.AccAddressFromBech32(msg.Signer) diff --git a/modules/core/04-channel/types/msgs_test.go b/modules/core/04-channel/types/msgs_test.go index e91a7b1d7e0..b0eda4e0ab2 100644 --- a/modules/core/04-channel/types/msgs_test.go +++ b/modules/core/04-channel/types/msgs_test.go @@ -780,21 +780,14 @@ func (suite *TypesTestSuite) TestMsgChannelUpgradeOpenValidateBasic() { expPass bool }{ { - "success", + "success: flushcomplete state", func() {}, true, }, { - "success: counterparty state set to TRYUPGRADE", - func() { - msg.CounterpartyChannelState = types.TRYUPGRADE - }, - true, - }, - { - "success: counterparty state set to ACKUPGRADE", + "success: open state", func() { - msg.CounterpartyChannelState = types.ACKUPGRADE + msg.CounterpartyChannelState = types.OPEN }, true, }, @@ -840,7 +833,7 @@ func (suite *TypesTestSuite) TestMsgChannelUpgradeOpenValidateBasic() { suite.Run(tc.name, func() { msg = types.NewMsgChannelUpgradeOpen( ibctesting.MockPort, ibctesting.FirstChannelID, - types.OPEN, suite.proof, + types.STATE_FLUSHCOMPLETE, suite.proof, height, addr, ) diff --git a/modules/core/keeper/msg_server.go b/modules/core/keeper/msg_server.go index 75003d12e69..f9edfed59a4 100644 --- a/modules/core/keeper/msg_server.go +++ b/modules/core/keeper/msg_server.go @@ -764,7 +764,7 @@ func (k Keeper) ChannelUpgradeTry(goCtx context.Context, msg *channeltypes.MsgCh if err != nil { ctx.Logger().Error("channel upgrade try failed", "error", errorsmod.Wrap(err, "channel upgrade try failed")) if channeltypes.IsUpgradeError(err) { - _ = k.ChannelKeeper.WriteErrorReceipt(ctx, msg.PortId, msg.ChannelId, upgrade.Fields, err.(*channeltypes.UpgradeError)) + _ = k.ChannelKeeper.WriteErrorReceipt(ctx, msg.PortId, msg.ChannelId, err.(*channeltypes.UpgradeError)) // NOTE: a FAILURE result is returned to the client and an error receipt is written to state. // This signals to the relayer to begin the cancel upgrade handshake subprotocol. @@ -936,7 +936,7 @@ func (k Keeper) ChannelUpgradeTimeout(goCtx context.Context, msg *channeltypes.M return nil, errorsmod.Wrapf(porttypes.ErrInvalidRoute, "route not found to module: %s", module) } - err = k.ChannelKeeper.ChanUpgradeTimeout(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyChannel, msg.PreviousErrorReceipt, msg.ProofChannel, msg.ProofErrorReceipt, msg.ProofHeight) + err = k.ChannelKeeper.ChanUpgradeTimeout(ctx, msg.PortId, msg.ChannelId, msg.CounterpartyChannel, msg.ProofChannel, msg.ProofHeight) if err != nil { return nil, errorsmod.Wrapf(err, "could not timeout upgrade for channel: %s", msg.ChannelId) }