Skip to content

Commit

Permalink
(chore) Refactor code around forwarding validation (#6706)
Browse files Browse the repository at this point in the history
* Refactor validation

* Fixed verification logic, added two tests

* Fix check for unwind

* removed unneeded indirection

* Update modules/apps/transfer/types/msgs.go

Co-authored-by: DimitrisJim <d.f.hilliard@gmail.com>

* Add docstring.

---------

Co-authored-by: Gjermund Garaba <bjaanes@gmail.com>
Co-authored-by: DimitrisJim <d.f.hilliard@gmail.com>
  • Loading branch information
3 people authored Jun 26, 2024
1 parent 6e1a082 commit 8f9691f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 35 deletions.
72 changes: 37 additions & 35 deletions modules/apps/transfer/types/msgs.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,22 @@ func NewMsgTransfer(
// NOTE: The recipient addresses format is not validated as the format defined by
// the chain is not known to IBC.
func (msg MsgTransfer) ValidateBasic() error {
if err := validateSourcePortAndChannel(msg); err != nil {
return err // The actual error and its message are already wrapped in the called function.
if err := msg.validateForwarding(); err != nil {
return err
}
if !msg.Forwarding.Unwind {
// We verify that portID and channelID are valid IDs only if
// we are not setting unwind to true.
// In that case, validation that they are empty is performed in
// validateForwarding().
if err := host.PortIdentifierValidator(msg.SourcePort); err != nil {
return errorsmod.Wrap(err, "invalid source port ID")
}

if err := host.ChannelIdentifierValidator(msg.SourceChannel); err != nil {
return errorsmod.Wrap(err, "invalid source channel ID")
}
}
if len(msg.Tokens) == 0 && !isValidIBCCoin(msg.Token) {
return errorsmod.Wrap(ibcerrors.ErrInvalidCoins, "either token or token array must be filled")
}
Expand All @@ -99,30 +111,42 @@ func (msg MsgTransfer) ValidateBasic() error {
return errorsmod.Wrapf(ErrInvalidMemo, "memo must not exceed %d bytes", MaximumMemoLength)
}

for _, coin := range msg.GetCoins() {
if err := validateIBCCoin(coin); err != nil {
return errorsmod.Wrapf(ibcerrors.ErrInvalidCoins, "%s: %s", err.Error(), coin.String())
}
}

return nil
}

// validateForwarding ensures that forwarding is set up correctly.
func (msg MsgTransfer) validateForwarding() error {
if !msg.ShouldBeForwarded() {
return nil
}
if err := msg.Forwarding.Validate(); err != nil {
return err
}

if msg.ShouldBeForwarded() {
if !msg.TimeoutHeight.IsZero() {
// when forwarding, the timeout height must not be set
if !msg.TimeoutHeight.IsZero() {
return errorsmod.Wrapf(ErrInvalidPacketTimeout, "timeout height must not be set if forwarding path hops is not empty: %s, %s", msg.TimeoutHeight, msg.Forwarding.Hops)
}
return errorsmod.Wrapf(ErrInvalidPacketTimeout, "timeout height must be zero if forwarding path hops is not empty: %s, %s", msg.TimeoutHeight, msg.Forwarding.Hops)
}

if msg.Forwarding.Unwind {
// When unwinding, we must have at most one token.
if msg.SourcePort != "" {
return errorsmod.Wrapf(ErrInvalidForwarding, "source port must be empty when unwind is set, got %s instead", msg.SourcePort)
}
if msg.SourceChannel != "" {
return errorsmod.Wrapf(ErrInvalidForwarding, "source channel must be empty when unwind is set, got %s instead", msg.SourceChannel)
}
if len(msg.GetCoins()) > 1 {
// When unwinding, we must have at most one token.
return errorsmod.Wrap(ibcerrors.ErrInvalidCoins, "cannot unwind more than one token")
}
}

for _, coin := range msg.GetCoins() {
if err := validateIBCCoin(coin); err != nil {
return errorsmod.Wrapf(ibcerrors.ErrInvalidCoins, "%s: %s", err.Error(), coin.String())
}
}

return nil
}

Expand Down Expand Up @@ -164,25 +188,3 @@ func validateIBCCoin(coin sdk.Coin) error {

return nil
}

func validateSourcePortAndChannel(msg MsgTransfer) error {
// If unwind is set, we want to ensure that port and channel are empty.
if msg.Forwarding.Unwind {
if msg.SourcePort != "" {
return errorsmod.Wrapf(ErrInvalidForwarding, "source port must be empty when unwind is set, got %s instead", msg.SourcePort)
}
if msg.SourceChannel != "" {
return errorsmod.Wrapf(ErrInvalidForwarding, "source channel must be empty when unwind is set, got %s instead", msg.SourceChannel)
}
return nil
}

// Otherwise, we just do the usual validation of the port and channel identifiers.
if err := host.PortIdentifierValidator(msg.SourcePort); err != nil {
return errorsmod.Wrap(err, "invalid source port ID")
}
if err := host.ChannelIdentifierValidator(msg.SourceChannel); err != nil {
return errorsmod.Wrap(err, "invalid source channel ID")
}
return nil
}
2 changes: 2 additions & 0 deletions modules/apps/transfer/types/msgs_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ func TestMsgTransferValidation(t *testing.T) {
{"invalid forwarding info port", types.NewMsgTransfer(validPort, validChannel, coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(false, types.Hop{PortId: invalidPort, ChannelId: validChannel})), types.ErrInvalidForwarding},
{"invalid forwarding info channel", types.NewMsgTransfer(validPort, validChannel, coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(false, types.Hop{PortId: validPort, ChannelId: invalidChannel})), types.ErrInvalidForwarding},
{"invalid forwarding info too many hops", types.NewMsgTransfer(validPort, validChannel, coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(false, generateHops(types.MaximumNumberOfForwardingHops+1)...)), types.ErrInvalidForwarding},
{"invalid portID when forwarding is set but unwind is not", types.NewMsgTransfer("", validChannel, coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(false, validHop)), host.ErrInvalidID},
{"invalid channelID when forwarding is set but unwind is not", types.NewMsgTransfer(validPort, "", coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(false, validHop)), host.ErrInvalidID},
{"unwind specified but source port is not empty", types.NewMsgTransfer(validPort, "", coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(true)), types.ErrInvalidForwarding},
{"unwind specified but source channel is not empty", types.NewMsgTransfer("", validChannel, coins, sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(true)), types.ErrInvalidForwarding},
{"unwind specified but more than one coin in the message", types.NewMsgTransfer("", "", coins.Add(sdk.NewCoin("atom", ibctesting.TestCoin.Amount)), sender, receiver, clienttypes.ZeroHeight(), 100, "", types.NewForwarding(true)), ibcerrors.ErrInvalidCoins},
Expand Down

0 comments on commit 8f9691f

Please sign in to comment.