diff --git a/modules/apps/29-fee/client/cli/tx.go b/modules/apps/29-fee/client/cli/tx.go index a9f66a71007..d14c8313dca 100644 --- a/modules/apps/29-fee/client/cli/tx.go +++ b/modules/apps/29-fee/client/cli/tx.go @@ -83,9 +83,9 @@ func NewPayPacketFeeAsyncTxCmd() *cobra.Command { TimeoutFee: timeoutFee, } - identifiedPacketFee := types.NewIdentifiedPacketFee(packetID, fee, sender, relayers) + packetFee := types.NewPacketFee(fee, sender, relayers) + msg := types.NewMsgPayPacketFeeAsync(packetID, packetFee) - msg := types.NewMsgPayPacketFeeAsync(identifiedPacketFee) return tx.GenerateOrBroadcastTxCLI(clientCtx, cmd.Flags(), msg) }, } diff --git a/modules/apps/29-fee/keeper/msg_server.go b/modules/apps/29-fee/keeper/msg_server.go index 2f0004be150..fdac1d27874 100644 --- a/modules/apps/29-fee/keeper/msg_server.go +++ b/modules/apps/29-fee/keeper/msg_server.go @@ -57,9 +57,7 @@ func (k Keeper) PayPacketFee(goCtx context.Context, msg *types.MsgPayPacketFee) func (k Keeper) PayPacketFeeAsync(goCtx context.Context, msg *types.MsgPayPacketFeeAsync) (*types.MsgPayPacketFeeAsyncResponse, error) { ctx := sdk.UnwrapSDKContext(goCtx) - // TODO: Update MsgPayPacketFeeAsync to include PacketFee in favour of IdentifiedPacketFee - packetFee := types.NewPacketFee(msg.IdentifiedPacketFee.Fee, msg.IdentifiedPacketFee.RefundAddress, msg.IdentifiedPacketFee.Relayers) - if err := k.EscrowPacketFee(ctx, msg.IdentifiedPacketFee.PacketId, packetFee); err != nil { + if err := k.EscrowPacketFee(ctx, msg.PacketId, msg.PacketFee); err != nil { return nil, err } diff --git a/modules/apps/29-fee/keeper/msg_server_test.go b/modules/apps/29-fee/keeper/msg_server_test.go index 0a8ad4b7d06..c6dc2f90856 100644 --- a/modules/apps/29-fee/keeper/msg_server_test.go +++ b/modules/apps/29-fee/keeper/msg_server_test.go @@ -119,12 +119,12 @@ func (suite *KeeperTestSuite) TestPayPacketFeeAsync() { seq, _ := suite.chainA.App.GetIBCKeeper().ChannelKeeper.GetNextSequenceSend(ctxA, suite.path.EndpointA.ChannelConfig.PortID, suite.path.EndpointA.ChannelID) // build fee - packetId := channeltypes.NewPacketId(channelID, suite.path.EndpointA.ChannelConfig.PortID, seq) - identifiedPacketFee := types.IdentifiedPacketFee{PacketId: packetId, Fee: fee, RefundAddress: refundAcc.String(), Relayers: []string{}} + packetID := channeltypes.NewPacketId(channelID, suite.path.EndpointA.ChannelConfig.PortID, seq) + packetFee := types.NewPacketFee(fee, refundAcc.String(), nil) tc.malleate() - msg := types.NewMsgPayPacketFeeAsync(identifiedPacketFee) + msg := types.NewMsgPayPacketFeeAsync(packetID, packetFee) _, err := suite.chainA.SendMsgs(msg) if tc.expPass { diff --git a/modules/apps/29-fee/types/fee.go b/modules/apps/29-fee/types/fee.go index 54dd9ee1435..4dd73893888 100644 --- a/modules/apps/29-fee/types/fee.go +++ b/modules/apps/29-fee/types/fee.go @@ -16,6 +16,25 @@ func NewPacketFee(fee Fee, refundAddr string, relayers []string) PacketFee { } } +// Validate performs basic stateless validation of the associated PacketFee +func (p PacketFee) Validate() error { + _, err := sdk.AccAddressFromBech32(p.RefundAddress) + if err != nil { + return sdkerrors.Wrap(err, "failed to convert RefundAddress into sdk.AccAddress") + } + + // enforce relayer is nil + if p.Relayers != nil { + return ErrRelayersNotNil + } + + if err := p.Fee.Validate(); err != nil { + return err + } + + return nil +} + // NewPacketFees creates and returns a new PacketFees struct including a list of type PacketFee func NewPacketFees(packetFees []PacketFee) PacketFees { return PacketFees{ diff --git a/modules/apps/29-fee/types/msgs.go b/modules/apps/29-fee/types/msgs.go index 50e6584e26b..2c0d8a52e32 100644 --- a/modules/apps/29-fee/types/msgs.go +++ b/modules/apps/29-fee/types/msgs.go @@ -118,22 +118,21 @@ func (msg MsgPayPacketFee) GetSignBytes() []byte { } // NewMsgPayPacketAsync creates a new instance of MsgPayPacketFee -func NewMsgPayPacketFeeAsync(identifiedPacketFee IdentifiedPacketFee) *MsgPayPacketFeeAsync { +func NewMsgPayPacketFeeAsync(packetID channeltypes.PacketId, packetFee PacketFee) *MsgPayPacketFeeAsync { return &MsgPayPacketFeeAsync{ - IdentifiedPacketFee: identifiedPacketFee, + PacketId: packetID, + PacketFee: packetFee, } } // ValidateBasic performs a basic check of the MsgPayPacketFeeAsync fields func (msg MsgPayPacketFeeAsync) ValidateBasic() error { - // signer check - _, err := sdk.AccAddressFromBech32(msg.IdentifiedPacketFee.RefundAddress) - if err != nil { - return sdkerrors.Wrap(err, "failed to convert msg.Signer into sdk.AccAddress") + if err := msg.PacketId.Validate(); err != nil { + return err } - if err = msg.IdentifiedPacketFee.Validate(); err != nil { - return sdkerrors.Wrap(err, "Invalid IdentifiedPacketFee") + if err := msg.PacketFee.Validate(); err != nil { + return err } return nil @@ -142,7 +141,7 @@ func (msg MsgPayPacketFeeAsync) ValidateBasic() error { // GetSigners implements sdk.Msg // The signer of the fee message must be the refund address func (msg MsgPayPacketFeeAsync) GetSigners() []sdk.AccAddress { - signer, err := sdk.AccAddressFromBech32(msg.IdentifiedPacketFee.RefundAddress) + signer, err := sdk.AccAddressFromBech32(msg.PacketFee.RefundAddress) if err != nil { panic(err) } diff --git a/modules/apps/29-fee/types/msgs_test.go b/modules/apps/29-fee/types/msgs_test.go index c4ee1a477f1..6ded61d407f 100644 --- a/modules/apps/29-fee/types/msgs_test.go +++ b/modules/apps/29-fee/types/msgs_test.go @@ -319,9 +319,10 @@ func TestMsgPayPacketFeeAsyncValidation(t *testing.T) { tc.malleate() fee = Fee{receiveFee, ackFee, timeoutFee} - packetId := channeltypes.NewPacketId(channelID, portID, seq) - identifiedPacketFee := IdentifiedPacketFee{PacketId: packetId, Fee: fee, RefundAddress: signer, Relayers: relayers} - msg := NewMsgPayPacketFeeAsync(identifiedPacketFee) + packetID := channeltypes.NewPacketId(channelID, portID, seq) + packetFee := NewPacketFee(fee, signer, relayers) + + msg := NewMsgPayPacketFeeAsync(packetID, packetFee) err := msg.ValidateBasic() @@ -335,20 +336,15 @@ func TestMsgPayPacketFeeAsyncValidation(t *testing.T) { // TestRegisterCounterpartyAddressGetSigners tests GetSigners func TestPayPacketFeeAsyncGetSigners(t *testing.T) { - addr := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()) - // build message - channelID := validChannelID - portID := validPortID - fee := Fee{validCoins, validCoins, validCoins} - seq := uint64(1) - packetId := channeltypes.NewPacketId(channelID, portID, seq) - identifiedPacketFee := IdentifiedPacketFee{PacketId: packetId, Fee: fee, RefundAddress: addr.String(), Relayers: nil} - msg := NewMsgPayPacketFeeAsync(identifiedPacketFee) + refundAddr := sdk.AccAddress(secp256k1.GenPrivKey().PubKey().Address()) + fee := NewFee(validCoins, validCoins, validCoins) - // GetSigners - res := msg.GetSigners() + packetID := channeltypes.NewPacketId(validChannelID, validPortID, 1) + packetFee := NewPacketFee(fee, refundAddr.String(), nil) - require.Equal(t, []sdk.AccAddress{addr}, res) + msg := NewMsgPayPacketFeeAsync(packetID, packetFee) + + require.Equal(t, []sdk.AccAddress{refundAddr}, msg.GetSigners()) } // TestMsgPayPacketFeeAsyncRoute tests Route for MsgPayPacketFeeAsync @@ -360,9 +356,10 @@ func TestMsgPayPacketFeeAsyncRoute(t *testing.T) { portID := validPortID fee := Fee{validCoins, validCoins, validCoins} seq := uint64(1) - packetId := channeltypes.NewPacketId(channelID, portID, seq) - identifiedPacketFee := IdentifiedPacketFee{PacketId: packetId, Fee: fee, RefundAddress: addr.String(), Relayers: nil} - msg := NewMsgPayPacketFeeAsync(identifiedPacketFee) + packetID := channeltypes.NewPacketId(channelID, portID, seq) + packetFee := NewPacketFee(fee, addr.String(), nil) + + msg := NewMsgPayPacketFeeAsync(packetID, packetFee) require.Equal(t, RouterKey, msg.Route()) } @@ -376,9 +373,10 @@ func TestMsgPayPacketFeeAsyncType(t *testing.T) { portID := validPortID fee := Fee{validCoins, validCoins, validCoins} seq := uint64(1) - packetId := channeltypes.NewPacketId(channelID, portID, seq) - identifiedPacketFee := IdentifiedPacketFee{PacketId: packetId, Fee: fee, RefundAddress: addr.String(), Relayers: nil} - msg := NewMsgPayPacketFeeAsync(identifiedPacketFee) + packetID := channeltypes.NewPacketId(channelID, portID, seq) + packetFee := NewPacketFee(fee, addr.String(), nil) + + msg := NewMsgPayPacketFeeAsync(packetID, packetFee) require.Equal(t, "payPacketFeeAsync", msg.Type()) } @@ -392,9 +390,10 @@ func TestMsgPayPacketFeeAsyncGetSignBytes(t *testing.T) { portID := validPortID fee := Fee{validCoins, validCoins, validCoins} seq := uint64(1) - packetId := channeltypes.NewPacketId(channelID, portID, seq) - identifiedPacketFee := IdentifiedPacketFee{PacketId: packetId, Fee: fee, RefundAddress: addr.String(), Relayers: nil} - msg := NewMsgPayPacketFeeAsync(identifiedPacketFee) + packetID := channeltypes.NewPacketId(channelID, portID, seq) + packetFee := NewPacketFee(fee, addr.String(), nil) + + msg := NewMsgPayPacketFeeAsync(packetID, packetFee) require.NotPanics(t, func() { _ = msg.GetSignBytes()