diff --git a/modules/apps/transfer/keeper/forwarding.go b/modules/apps/transfer/keeper/forwarding.go index c33dc58f460..1029d6369a8 100644 --- a/modules/apps/transfer/keeper/forwarding.go +++ b/modules/apps/transfer/keeper/forwarding.go @@ -44,13 +44,13 @@ func (k Keeper) forwardPacket(ctx sdk.Context, data types.FungibleTokenPacketDat } // ackForwardPacketSuccess writes a successful async acknowledgement for the prevPacket -func (k Keeper) ackForwardPacketSuccess(ctx sdk.Context, prevPacket channeltypes.Packet) error { +func (k Keeper) ackForwardPacketSuccess(ctx sdk.Context, prevPacket, forwardedPacket channeltypes.Packet) error { forwardAck := channeltypes.NewResultAcknowledgement([]byte{byte(1)}) - return k.acknowledgeForwardedPacket(ctx, prevPacket, forwardAck) + return k.acknowledgeForwardedPacket(ctx, prevPacket, forwardedPacket, forwardAck) } // ackForwardPacketError reverts the receive packet logic that occurs in the middle chain and writes the async ack for the prevPacket -func (k Keeper) ackForwardPacketError(ctx sdk.Context, prevPacket channeltypes.Packet, failedPacketData types.FungibleTokenPacketDataV2) error { +func (k Keeper) ackForwardPacketError(ctx sdk.Context, prevPacket, forwardedPacket channeltypes.Packet, failedPacketData types.FungibleTokenPacketDataV2) error { // the forwarded packet has failed, thus the funds have been refunded to the intermediate address. // we must revert the changes that came from successfully receiving the tokens on our chain // before propagating the error acknowledgement back to original sender chain @@ -59,27 +59,32 @@ func (k Keeper) ackForwardPacketError(ctx sdk.Context, prevPacket channeltypes.P } forwardAck := channeltypes.NewErrorAcknowledgement(types.ErrForwardedPacketFailed) - return k.acknowledgeForwardedPacket(ctx, prevPacket, forwardAck) + return k.acknowledgeForwardedPacket(ctx, prevPacket, forwardedPacket, forwardAck) } // ackForwardPacketTimeout reverts the receive packet logic that occurs in the middle chain and writes a failed async ack for the prevPacket -func (k Keeper) ackForwardPacketTimeout(ctx sdk.Context, prevPacket channeltypes.Packet, timeoutPacketData types.FungibleTokenPacketDataV2) error { +func (k Keeper) ackForwardPacketTimeout(ctx sdk.Context, prevPacket, forwardedPacket channeltypes.Packet, timeoutPacketData types.FungibleTokenPacketDataV2) error { if err := k.revertForwardedPacket(ctx, prevPacket, timeoutPacketData); err != nil { return err } forwardAck := channeltypes.NewErrorAcknowledgement(types.ErrForwardedPacketTimedOut) - return k.acknowledgeForwardedPacket(ctx, prevPacket, forwardAck) + return k.acknowledgeForwardedPacket(ctx, prevPacket, forwardedPacket, forwardAck) } // acknowledgeForwardedPacket writes the async acknowledgement for packet -func (k Keeper) acknowledgeForwardedPacket(ctx sdk.Context, packet channeltypes.Packet, ack channeltypes.Acknowledgement) error { +func (k Keeper) acknowledgeForwardedPacket(ctx sdk.Context, packet, forwardedPacket channeltypes.Packet, ack channeltypes.Acknowledgement) error { capability, ok := k.scopedKeeper.GetCapability(ctx, host.ChannelCapabilityPath(packet.DestinationPort, packet.DestinationChannel)) if !ok { return errorsmod.Wrap(channeltypes.ErrChannelCapabilityNotFound, "module does not own channel capability") } - return k.ics4Wrapper.WriteAcknowledgement(ctx, capability, packet, ack) + if err := k.ics4Wrapper.WriteAcknowledgement(ctx, capability, packet, ack); err != nil { + return err + } + + k.deleteForwardedPacket(ctx, forwardedPacket.SourcePort, forwardedPacket.SourceChannel, forwardedPacket.Sequence) + return nil } // revertForwardedPacket reverts the logic of receive packet that occurs in the middle chains during a packet forwarding. diff --git a/modules/apps/transfer/keeper/keeper.go b/modules/apps/transfer/keeper/keeper.go index 2fd82e639a3..d9d8d60b546 100644 --- a/modules/apps/transfer/keeper/keeper.go +++ b/modules/apps/transfer/keeper/keeper.go @@ -328,3 +328,11 @@ func (k Keeper) GetForwardedPacket(ctx sdk.Context, portID, channelID string, se return storedPacket, true } + +// deleteForwardedPacket deletes the forwarded packet from the store. +func (k Keeper) deleteForwardedPacket(ctx sdk.Context, portID, channelID string, sequence uint64) { + store := ctx.KVStore(k.storeKey) + packetKey := types.PacketForwardKey(portID, channelID, sequence) + + store.Delete(packetKey) +} diff --git a/modules/apps/transfer/keeper/relay.go b/modules/apps/transfer/keeper/relay.go index ce65244452e..6d2d8d1e5e5 100644 --- a/modules/apps/transfer/keeper/relay.go +++ b/modules/apps/transfer/keeper/relay.go @@ -285,7 +285,7 @@ func (k Keeper) OnAcknowledgementPacket(ctx sdk.Context, packet channeltypes.Pac switch ack.Response.(type) { case *channeltypes.Acknowledgement_Result: if isForwarded { - return k.ackForwardPacketSuccess(ctx, prevPacket) + return k.ackForwardPacketSuccess(ctx, prevPacket, packet) } // the acknowledgement succeeded on the receiving chain so nothing @@ -297,7 +297,7 @@ func (k Keeper) OnAcknowledgementPacket(ctx sdk.Context, packet channeltypes.Pac return err } if isForwarded { - return k.ackForwardPacketError(ctx, prevPacket, data) + return k.ackForwardPacketError(ctx, prevPacket, packet, data) } return nil @@ -316,7 +316,7 @@ func (k Keeper) OnTimeoutPacket(ctx sdk.Context, packet channeltypes.Packet, dat prevPacket, isForwarded := k.GetForwardedPacket(ctx, packet.SourcePort, packet.SourceChannel, packet.Sequence) if isForwarded { - return k.ackForwardPacketTimeout(ctx, prevPacket, data) + return k.ackForwardPacketTimeout(ctx, prevPacket, packet, data) } return nil diff --git a/modules/apps/transfer/keeper/relay_forwarding_test.go b/modules/apps/transfer/keeper/relay_forwarding_test.go index 591321db516..b36f6e28a49 100644 --- a/modules/apps/transfer/keeper/relay_forwarding_test.go +++ b/modules/apps/transfer/keeper/relay_forwarding_test.go @@ -317,6 +317,10 @@ func (suite *KeeperTestSuite) TestSimplifiedHappyPathForwarding() { err = path2.EndpointB.UpdateClient() suite.Require().NoError(err) + // B should now have deleted the forwarded packet. + _, found := suite.chainB.GetSimApp().TransferKeeper.GetForwardedPacket(suite.chainB.GetContext(), packetFromAtoB.DestinationPort, packetFromAtoB.DestinationChannel, packetFromAtoB.Sequence) + suite.Require().False(found, "Chain B should have deleted its forwarded packet") + result, err = path2.EndpointB.RecvPacketWithResult(packetFromBtoC) suite.Require().NoError(err) suite.Require().NotNil(result) @@ -562,6 +566,10 @@ func (suite *KeeperTestSuite) TestAcknowledgementFailureScenario5Forwarding() { err = path1.EndpointB.AcknowledgePacket(packetFromBtoA, errorAckOnA.Acknowledgement()) suite.Require().NoError(err) + // Check that B deleted the forwarded packet. + _, found = suite.chainB.GetSimApp().TransferKeeper.GetForwardedPacket(suite.chainB.GetContext(), forwardedPacket.SourcePort, forwardedPacket.SourceChannel, forwardedPacket.Sequence) + suite.Require().False(found, "chain B should have deleted the forwarded packet mapping") + // Check that Escrow B has been refunded amount coin = sdk.NewCoin(denomAB.IBCDenom(), amount) totalEscrowChainB = suite.chainB.GetSimApp().TransferKeeper.GetTotalEscrowForDenom(suite.chainB.GetContext(), coin.GetDenom())