Skip to content

Commit

Permalink
refactor: restructure timeout type (#5404)
Browse files Browse the repository at this point in the history
* refactor: apply issue design to timeout type

* chore: add godocs

* tests: add tests for elapsed and ErrTimeoutElapsed functions

* rm: unfinished tests

* test: tests for ErrTimeoutNotReached

* test: add tests for ErrTimeoutNotReached

* review suggestion: switcharoo

* chore: add in-code comment

---------

Co-authored-by: Carlos Rodriguez <carlos@interchain.io>
  • Loading branch information
colin-axner and Carlos Rodriguez authored Dec 19, 2023
1 parent 26fbb37 commit 5ec221a
Show file tree
Hide file tree
Showing 8 changed files with 236 additions and 30 deletions.
6 changes: 4 additions & 2 deletions modules/core/04-channel/keeper/packet.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,12 @@ func (k Keeper) AcknowledgePacket(
counterpartyUpgrade, found := k.GetCounterpartyUpgrade(ctx, packet.GetSourcePort(), packet.GetSourceChannel())
if found {
timeout := counterpartyUpgrade.Timeout
if hasPassed, err := timeout.HasPassed(ctx); hasPassed {
selfHeight, selfTimestamp := clienttypes.GetSelfHeight(ctx), uint64(ctx.BlockTime().UnixNano())

if timeout.Elapsed(selfHeight, selfTimestamp) {
// packet flushing timeout has expired, abort the upgrade and return nil,
// committing an error receipt to state, restoring the channel and successfully acknowledging the packet.
k.MustAbortUpgrade(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), err)
k.MustAbortUpgrade(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), timeout.ErrTimeoutElapsed(selfHeight, selfTimestamp))
return nil
}

Expand Down
7 changes: 5 additions & 2 deletions modules/core/04-channel/keeper/timeout.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
sdk "github.com/cosmos/cosmos-sdk/types"

capabilitytypes "github.com/cosmos/ibc-go/modules/capability/types"
clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types"
connectiontypes "github.com/cosmos/ibc-go/v8/modules/core/03-connection/types"
"github.com/cosmos/ibc-go/v8/modules/core/04-channel/types"
host "github.com/cosmos/ibc-go/v8/modules/core/24-host"
Expand Down Expand Up @@ -157,10 +158,12 @@ func (k Keeper) TimeoutExecuted(
// then we can move to flushing complete if the timeout has not passed and there are no in-flight packets
if found {
timeout := counterpartyUpgrade.Timeout
if hasPassed, err := timeout.HasPassed(ctx); hasPassed {
selfHeight, selfTimestamp := clienttypes.GetSelfHeight(ctx), uint64(ctx.BlockTime().UnixNano())

if timeout.Elapsed(selfHeight, selfTimestamp) {
// packet flushing timeout has expired, abort the upgrade and return nil,
// committing an error receipt to state, restoring the channel and successfully timing out the packet.
k.MustAbortUpgrade(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), err)
k.MustAbortUpgrade(ctx, packet.GetSourcePort(), packet.GetSourceChannel(), timeout.ErrTimeoutElapsed(selfHeight, selfTimestamp))
} else if !k.HasInflightPackets(ctx, packet.GetSourcePort(), packet.GetSourceChannel()) {
// set the channel state to flush complete if all packets have been flushed.
channel.State = types.FLUSHCOMPLETE
Expand Down
12 changes: 8 additions & 4 deletions modules/core/04-channel/keeper/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,8 +302,10 @@ func (k Keeper) ChanUpgradeAck(
}

timeout := counterpartyUpgrade.Timeout
if hasPassed, err := timeout.HasPassed(ctx); hasPassed {
return types.NewUpgradeError(channel.UpgradeSequence, errorsmod.Wrap(err, "counterparty upgrade timeout has passed"))
selfHeight, selfTimestamp := clienttypes.GetSelfHeight(ctx), uint64(ctx.BlockTime().UnixNano())

if timeout.Elapsed(selfHeight, selfTimestamp) {
return types.NewUpgradeError(channel.UpgradeSequence, errorsmod.Wrap(timeout.ErrTimeoutElapsed(selfHeight, selfTimestamp), "counterparty upgrade timeout elapsed"))
}

return nil
Expand Down Expand Up @@ -409,8 +411,10 @@ func (k Keeper) ChanUpgradeConfirm(
}

timeout := counterpartyUpgrade.Timeout
if hasPassed, err := timeout.HasPassed(ctx); hasPassed {
return types.NewUpgradeError(channel.UpgradeSequence, errorsmod.Wrap(err, "counterparty upgrade timeout has passed"))
selfHeight, selfTimestamp := clienttypes.GetSelfHeight(ctx), uint64(ctx.BlockTime().UnixNano())

if timeout.Elapsed(selfHeight, selfTimestamp) {
return types.NewUpgradeError(channel.UpgradeSequence, errorsmod.Wrap(timeout.ErrTimeoutElapsed(selfHeight, selfTimestamp), "counterparty upgrade timeout elapsed"))
}

return nil
Expand Down
4 changes: 2 additions & 2 deletions modules/core/04-channel/keeper/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -516,7 +516,7 @@ func (suite *KeeperTestSuite) TestChanUpgradeAck() {
err = path.EndpointA.UpdateClient()
suite.Require().NoError(err)
},
types.NewUpgradeError(1, types.ErrInvalidUpgrade),
types.NewUpgradeError(1, types.ErrTimeoutElapsed),
},
}

Expand Down Expand Up @@ -813,7 +813,7 @@ func (suite *KeeperTestSuite) TestChanUpgradeConfirm() {
err := path.EndpointB.UpdateClient()
suite.Require().NoError(err)
},
types.NewUpgradeError(1, types.ErrInvalidUpgrade),
types.NewUpgradeError(1, types.ErrTimeoutElapsed),
},
}

Expand Down
2 changes: 2 additions & 0 deletions modules/core/04-channel/types/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,4 +52,6 @@ var (
ErrPendingInflightPackets = errorsmod.Register(SubModuleName, 36, "pending inflight packets exist")
ErrUpgradeTimeoutFailed = errorsmod.Register(SubModuleName, 37, "upgrade timeout failed")
ErrInvalidPruningLimit = errorsmod.Register(SubModuleName, 38, "invalid pruning limit")
ErrTimeoutNotReached = errorsmod.Register(SubModuleName, 39, "timeout not reached")
ErrTimeoutElapsed = errorsmod.Register(SubModuleName, 40, "timeout elapsed")
)
52 changes: 34 additions & 18 deletions modules/core/04-channel/types/timeout.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
package types

import (
"time"

errorsmod "cosmossdk.io/errors"

sdk "github.com/cosmos/cosmos-sdk/types"

clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types"
)

Expand All @@ -18,27 +14,47 @@ func NewTimeout(height clienttypes.Height, timestamp uint64) Timeout {
}
}

// IsValid returns true if either the height or timestamp is non-zero
// IsValid returns true if either the height or timestamp is non-zero.
func (t Timeout) IsValid() bool {
return !t.Height.IsZero() || t.Timestamp != 0
}

// TODO: Update after https://github.com/cosmos/ibc-go/issues/3483 has been resolved
// HasPassed returns true if the upgrade has passed the timeout height or timestamp
func (t Timeout) HasPassed(ctx sdk.Context) (bool, error) {
if !t.IsValid() {
return true, errorsmod.Wrap(ErrInvalidUpgrade, "upgrade timeout cannot be empty")
}
// Elapsed returns true if either the provided height or timestamp is past the
// respective absolute timeout values.
func (t Timeout) Elapsed(height clienttypes.Height, timestamp uint64) bool {
return t.heightElapsed(height) || t.timestampElapsed(timestamp)
}

selfHeight, timeoutHeight := clienttypes.GetSelfHeight(ctx), t.Height
if selfHeight.GTE(timeoutHeight) && timeoutHeight.GT(clienttypes.ZeroHeight()) {
return true, errorsmod.Wrapf(ErrInvalidUpgrade, "block height >= upgrade timeout height (%s >= %s)", selfHeight, timeoutHeight)
// ErrTimeoutElapsed returns a timeout elapsed error indicating which timeout value
// has elapsed.
func (t Timeout) ErrTimeoutElapsed(height clienttypes.Height, timestamp uint64) error {
if t.heightElapsed(height) {
return errorsmod.Wrapf(ErrTimeoutElapsed, "current height: %s, timeout height %s", height, t.Height)
}

selfTime, timeoutTimestamp := uint64(ctx.BlockTime().UnixNano()), t.Timestamp
if selfTime >= timeoutTimestamp && timeoutTimestamp > 0 {
return true, errorsmod.Wrapf(ErrInvalidUpgrade, "block timestamp >= upgrade timeout timestamp (%s >= %s)", ctx.BlockTime(), time.Unix(0, int64(timeoutTimestamp)))
return errorsmod.Wrapf(ErrTimeoutElapsed, "current timestamp: %d, timeout timestamp %d", timestamp, t.Timestamp)
}

// ErrTimeoutNotReached returns a timeout not reached error indicating which timeout value
// has not been reached.
func (t Timeout) ErrTimeoutNotReached(height clienttypes.Height, timestamp uint64) error {
// only return height information if the height is set
// t.heightElapsed() will return false when it is empty
if !t.Height.IsZero() && !t.heightElapsed(height) {
return errorsmod.Wrapf(ErrTimeoutNotReached, "current height: %s, timeout height %s", height, t.Height)
}

return false, nil
return errorsmod.Wrapf(ErrTimeoutNotReached, "current timestamp: %d, timeout timestamp %d", timestamp, t.Timestamp)
}

// heightElapsed returns true if the timeout height is non empty
// and the timeout height is greater than or equal to the relative height.
func (t Timeout) heightElapsed(height clienttypes.Height) bool {
return !t.Height.IsZero() && height.GTE(t.Height)
}

// timestampElapsed returns true if the timeout timestamp is non empty
// and the timeout timestamp is greater than or equal to the relative timestamp.
func (t Timeout) timestampElapsed(timestamp uint64) bool {
return t.Timestamp != 0 && timestamp >= t.Timestamp
}
179 changes: 179 additions & 0 deletions modules/core/04-channel/types/timeout_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package types_test

import (
errorsmod "cosmossdk.io/errors"

clienttypes "github.com/cosmos/ibc-go/v8/modules/core/02-client/types"
"github.com/cosmos/ibc-go/v8/modules/core/04-channel/types"
)
Expand Down Expand Up @@ -57,3 +59,180 @@ func (suite *TypesTestSuite) TestIsValid() {
})
}
}

func (suite *TypesTestSuite) TestElapsed() {
// elapsed is expected to be true when either timeout height or timestamp
// is greater than or equal to 2
var (
height = clienttypes.NewHeight(0, 2)
timestamp = uint64(2)
)

testCases := []struct {
name string
timeout types.Timeout
expElapsed bool
}{
{
"elapsed: both timeout with height and timestamp",
types.NewTimeout(height, timestamp),
true,
},
{
"elapsed: timeout with height and zero timestamp",
types.NewTimeout(height, 0),
true,
},
{
"elapsed: timeout with timestamp and zero height",
types.NewTimeout(clienttypes.ZeroHeight(), timestamp),
true,
},
{
"elapsed: height elapsed, timestamp did not",
types.NewTimeout(height, timestamp+1),
true,
},
{
"elapsed: timestamp elapsed, height did not",
types.NewTimeout(height.Increment().(clienttypes.Height), timestamp),
true,
},
{
"elapsed: timestamp elapsed when less than current timestamp",
types.NewTimeout(clienttypes.ZeroHeight(), timestamp-1),
true,
},
{
"elapsed: height elapsed when less than current height",
types.NewTimeout(clienttypes.NewHeight(0, 1), 0),
true,
},
{
"not elapsed: invalid timeout",
types.NewTimeout(clienttypes.ZeroHeight(), 0),
false,
},
{
"not elapsed: neither height nor timeout elapsed",
types.NewTimeout(height.Increment().(clienttypes.Height), timestamp+1),
false,
},
{
"not elapsed: timeout not reached with height and zero timestamp",
types.NewTimeout(height.Increment().(clienttypes.Height), 0),
false,
},
{
"elapsed: timeout not reached with timestamp and zero height",
types.NewTimeout(clienttypes.ZeroHeight(), timestamp+1),
false,
},
}

for _, tc := range testCases {
tc := tc
suite.Run(tc.name, func() {
elapsed := tc.timeout.Elapsed(height, timestamp)
suite.Require().Equal(tc.expElapsed, elapsed)
})
}
}

func (suite *TypesTestSuite) TestErrTimeoutElapsed() {
// elapsed is expected to be true when either timeout height or timestamp
// is greater than or equal to 2
var (
height = clienttypes.NewHeight(0, 2)
timestamp = uint64(2)
)

testCases := []struct {
name string
timeout types.Timeout
expError error
}{
{
"both timeout with height and timestamp",
types.NewTimeout(height, timestamp),
errorsmod.Wrapf(types.ErrTimeoutElapsed, "current height: %s, timeout height %s", height, height),
},
{
"timeout with height and zero timestamp",
types.NewTimeout(height, 0),
errorsmod.Wrapf(types.ErrTimeoutElapsed, "current height: %s, timeout height %s", height, height),
},
{
"timeout with timestamp and zero height",
types.NewTimeout(clienttypes.ZeroHeight(), timestamp),
errorsmod.Wrapf(types.ErrTimeoutElapsed, "current timestamp: %d, timeout timestamp %d", timestamp, timestamp),
},
{
"height elapsed, timestamp did not",
types.NewTimeout(height, timestamp+1),
errorsmod.Wrapf(types.ErrTimeoutElapsed, "current height: %s, timeout height %s", height, height),
},
{
"timestamp elapsed, height did not",
types.NewTimeout(height.Increment().(clienttypes.Height), timestamp),
errorsmod.Wrapf(types.ErrTimeoutElapsed, "current timestamp: %d, timeout timestamp %d", timestamp, timestamp),
},
{
"height elapsed when less than current height",
types.NewTimeout(clienttypes.NewHeight(0, 1), 0),
errorsmod.Wrapf(types.ErrTimeoutElapsed, "current height: %s, timeout height %s", height, clienttypes.NewHeight(0, 1)),
},
{
"timestamp elapsed when less than current timestamp",
types.NewTimeout(clienttypes.ZeroHeight(), timestamp-1),
errorsmod.Wrapf(types.ErrTimeoutElapsed, "current timestamp: %d, timeout timestamp %d", timestamp, timestamp-1),
},
}

for _, tc := range testCases {
tc := tc
suite.Run(tc.name, func() {
err := tc.timeout.ErrTimeoutElapsed(height, timestamp)
suite.Require().Equal(tc.expError.Error(), err.Error())
})
}
}

func (suite *TypesTestSuite) TestErrTimeoutNotReached() {
// elapsed is expected to be true when either timeout height or timestamp
// is greater than or equal to 2
var (
height = clienttypes.NewHeight(0, 2)
timestamp = uint64(2)
)

testCases := []struct {
name string
timeout types.Timeout
expError error
}{
{
"neither timeout reached with height and timestamp",
types.NewTimeout(height.Increment().(clienttypes.Height), timestamp+1),
errorsmod.Wrapf(types.ErrTimeoutNotReached, "current height: %s, timeout height %s", height, height.Increment().(clienttypes.Height)),
},
{
"timeout not reached with height and zero timestamp",
types.NewTimeout(height.Increment().(clienttypes.Height), 0),
errorsmod.Wrapf(types.ErrTimeoutNotReached, "current height: %s, timeout height %s", height, height.Increment().(clienttypes.Height)),
},
{
"timeout not reached with timestamp and zero height",
types.NewTimeout(clienttypes.ZeroHeight(), timestamp+1),
errorsmod.Wrapf(types.ErrTimeoutNotReached, "current timestamp: %d, timeout timestamp %d", timestamp, timestamp+1),
},
}

for _, tc := range testCases {
tc := tc
suite.Run(tc.name, func() {
err := tc.timeout.ErrTimeoutNotReached(height, timestamp)
suite.Require().Equal(tc.expError.Error(), err.Error())
})
}
}
4 changes: 2 additions & 2 deletions modules/core/keeper/msg_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1326,9 +1326,9 @@ func (suite *KeeperTestSuite) TestChannelUpgradeConfirm() {
{
"core handler returns error and writes upgrade error receipt",
func() {
// force an upgrade error by modifying the counterparty channel upgrade timeout to be no longer valid
// force an upgrade error by modifying the counterparty channel upgrade timeout to be elapsed
upgrade := path.EndpointA.GetChannelUpgrade()
upgrade.Timeout = channeltypes.NewTimeout(clienttypes.ZeroHeight(), 0)
upgrade.Timeout = channeltypes.NewTimeout(clienttypes.ZeroHeight(), uint64(path.EndpointB.Chain.CurrentHeader.Time.UnixNano()))

path.EndpointA.SetChannelUpgrade(upgrade)

Expand Down

0 comments on commit 5ec221a

Please sign in to comment.