diff --git a/modules/apps/transfer/types/msgs.go b/modules/apps/transfer/types/msgs.go index eafb6abd59e..be3420d233f 100644 --- a/modules/apps/transfer/types/msgs.go +++ b/modules/apps/transfer/types/msgs.go @@ -69,10 +69,22 @@ func NewMsgTransfer( // NOTE: The recipient addresses format is not validated as the format defined by // the chain is not known to IBC. func (msg MsgTransfer) ValidateBasic() error { - if err := validateSourcePortAndChannel(msg); err != nil { - return err // The actual error and its message are already wrapped in the called function. + if err := msg.validateForwarding(); err != nil { + return err } + if !msg.Forwarding.Unwind { + // We verify that portID and channelID are valid IDs only if + // we are not setting unwind to true. + // In that case, validation that they are empty is performed in + // validateForwarding(). + if err := host.PortIdentifierValidator(msg.SourcePort); err != nil { + return errorsmod.Wrap(err, "invalid source port ID") + } + if err := host.ChannelIdentifierValidator(msg.SourceChannel); err != nil { + return errorsmod.Wrap(err, "invalid source channel ID") + } + } if len(msg.Tokens) == 0 && !isValidIBCCoin(msg.Token) { return errorsmod.Wrap(ibcerrors.ErrInvalidCoins, "either token or token array must be filled") } @@ -99,30 +111,42 @@ func (msg MsgTransfer) ValidateBasic() error { return errorsmod.Wrapf(ErrInvalidMemo, "memo must not exceed %d bytes", MaximumMemoLength) } + for _, coin := range msg.GetCoins() { + if err := validateIBCCoin(coin); err != nil { + return errorsmod.Wrapf(ibcerrors.ErrInvalidCoins, "%s: %s", err.Error(), coin.String()) + } + } + + return nil +} + +// validateForwarding ensures that forwarding is set up correctly. +func (msg MsgTransfer) validateForwarding() error { + if !msg.ShouldBeForwarded() { + return nil + } if err := msg.Forwarding.Validate(); err != nil { return err } - if msg.ShouldBeForwarded() { + if !msg.TimeoutHeight.IsZero() { // when forwarding, the timeout height must not be set - if !msg.TimeoutHeight.IsZero() { - return errorsmod.Wrapf(ErrInvalidPacketTimeout, "timeout height must not be set if forwarding path hops is not empty: %s, %s", msg.TimeoutHeight, msg.Forwarding.Hops) - } + return errorsmod.Wrapf(ErrInvalidPacketTimeout, "timeout height must be zero if forwarding path hops is not empty: %s, %s", msg.TimeoutHeight, msg.Forwarding.Hops) } if msg.Forwarding.Unwind { - // When unwinding, we must have at most one token. + if msg.SourcePort != "" { + return errorsmod.Wrapf(ErrInvalidForwarding, "source port must be empty when unwind is set, got %s instead", msg.SourcePort) + } + if msg.SourceChannel != "" { + return errorsmod.Wrapf(ErrInvalidForwarding, "source channel must be empty when unwind is set, got %s instead", msg.SourceChannel) + } if len(msg.GetCoins()) > 1 { + // When unwinding, we must have at most one token. return errorsmod.Wrap(ibcerrors.ErrInvalidCoins, "cannot unwind more than one token") } } - for _, coin := range msg.GetCoins() { - if err := validateIBCCoin(coin); err != nil { - return errorsmod.Wrapf(ibcerrors.ErrInvalidCoins, "%s: %s", err.Error(), coin.String()) - } - } - return nil } @@ -164,25 +188,3 @@ func validateIBCCoin(coin sdk.Coin) error { return nil } - -func validateSourcePortAndChannel(msg MsgTransfer) error { - // If unwind is set, we want to ensure that port and channel are empty. - if msg.Forwarding.Unwind { - if msg.SourcePort != "" { - return errorsmod.Wrapf(ErrInvalidForwarding, "source port must be empty when unwind is set, got %s instead", msg.SourcePort) - } - if msg.SourceChannel != "" { - return errorsmod.Wrapf(ErrInvalidForwarding, "source channel must be empty when unwind is set, got %s instead", msg.SourceChannel) - } - return nil - } - - // Otherwise, we just do the usual validation of the port and channel identifiers. - if err := host.PortIdentifierValidator(msg.SourcePort); err != nil { - return errorsmod.Wrap(err, "invalid source port ID") - } - if err := host.ChannelIdentifierValidator(msg.SourceChannel); err != nil { - return errorsmod.Wrap(err, "invalid source channel ID") - } - return nil -} diff --git a/modules/apps/transfer/types/msgs_test.go b/modules/apps/transfer/types/msgs_test.go index b0ccbb611bf..2fea4f4a5b2 100644 --- a/modules/apps/transfer/types/msgs_test.go +++ b/modules/apps/transfer/types/msgs_test.go @@ -87,6 +87,8 @@ func TestMsgTransferValidation(t *testing.T) { {"invalid forwarding info port", types.NewMsgTransfer(validPort, validChannel, coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(false, types.Hop{PortId: invalidPort, ChannelId: validChannel})), types.ErrInvalidForwarding}, {"invalid forwarding info channel", types.NewMsgTransfer(validPort, validChannel, coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(false, types.Hop{PortId: validPort, ChannelId: invalidChannel})), types.ErrInvalidForwarding}, {"invalid forwarding info too many hops", types.NewMsgTransfer(validPort, validChannel, coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(false, generateHops(types.MaximumNumberOfForwardingHops+1)...)), types.ErrInvalidForwarding}, + {"invalid portID when forwarding is set but unwind is not", types.NewMsgTransfer("", validChannel, coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(false, validHop)), host.ErrInvalidID}, + {"invalid channelID when forwarding is set but unwind is not", types.NewMsgTransfer(validPort, "", coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(false, validHop)), host.ErrInvalidID}, {"unwind specified but source port is not empty", types.NewMsgTransfer(validPort, "", coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(true)), types.ErrInvalidForwarding}, {"unwind specified but source channel is not empty", types.NewMsgTransfer("", validChannel, coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(true)), types.ErrInvalidForwarding}, {"unwind specified but more than one coin in the message", types.NewMsgTransfer("", "", coins.Add(sdk.NewCoin("atom", ibctesting.TestCoin.Amount)), sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(true)), ibcerrors.ErrInvalidCoins},