Skip to content

Commit

Permalink
chore: refactor UpgradeError to use built in errors functions (#5704) (
Browse files Browse the repository at this point in the history
…#5715)

(cherry picked from commit 97ea045)

Co-authored-by: Cian Hatton <cian@interchain.io>
  • Loading branch information
mergify[bot] and chatton authored Jan 24, 2024
1 parent 64897d9 commit 1718cbe
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 30 deletions.
33 changes: 13 additions & 20 deletions modules/core/04-channel/types/upgrade.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,13 +83,21 @@ func (u *UpgradeError) Error() string {
return u.err.Error()
}

// Is returns true if the underlying error is of the given err type.
func (u *UpgradeError) Is(err error) bool {
return errors.Is(u.err, err)
// Is returns true if the of the provided error is an upgrade error.
func (*UpgradeError) Is(err error) bool {
_, ok := err.(*UpgradeError)
return ok
}

// Unwrap returns the base error that caused the upgrade to fail.
// Unwrap returns the next error in the error chain.
// If there is no next error, Unwrap returns nil.
func (u *UpgradeError) Unwrap() error {
return u.err
}

// Cause implements the sdk error interface which uses this function to unwrap the error in various functions such as `wrappedError.Is()`.
// Cause returns the underlying error which caused the upgrade to fail.
func (u *UpgradeError) Cause() error {
baseError := u.err
for {
if err := errors.Unwrap(baseError); err != nil {
Expand All @@ -100,12 +108,6 @@ func (u *UpgradeError) Unwrap() error {
}
}

// Cause implements the sdk error interface which uses this function to unwrap the error in various functions such as `wrappedError.Is()`.
// Cause returns the underlying error which caused the upgrade to fail.
func (u *UpgradeError) Cause() error {
return u.err
}

// GetErrorReceipt returns an error receipt with the code from the underlying error type stripped.
func (u *UpgradeError) GetErrorReceipt() ErrorReceipt {
// restoreErrorString defines a string constant included in error receipts.
Expand All @@ -122,14 +124,5 @@ func (u *UpgradeError) GetErrorReceipt() ErrorReceipt {
// IsUpgradeError returns true if err is of type UpgradeError or contained
// in the error chain of err and false otherwise.
func IsUpgradeError(err error) bool {
for {
_, ok := err.(*UpgradeError)
if ok {
return true
}

if err = errors.Unwrap(err); err == nil {
return false
}
}
return errors.Is(err, &UpgradeError{})
}
35 changes: 25 additions & 10 deletions modules/core/04-channel/types/upgrade_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,18 +159,33 @@ func (suite *TypesTestSuite) TestGetErrorReceipt() {
suite.Require().Equal(upgradeError2.GetErrorReceipt().Message, upgradeError.GetErrorReceipt().Message)
}

// TestUpgradeErrorUnwrap tests that the underlying error is not modified when Unwrap is called.
// TestUpgradeErrorUnwrap tests that the underlying error is returned by Unwrap.
func (suite *TypesTestSuite) TestUpgradeErrorUnwrap() {
baseUnderlyingError := errorsmod.Wrap(types.ErrInvalidChannel, "base error")
wrappedErr := errorsmod.Wrap(baseUnderlyingError, "wrapped error")
upgradeError := types.NewUpgradeError(1, wrappedErr)

originalUpgradeError := upgradeError.Error()
unWrapped := errors.Unwrap(upgradeError)
postUnwrapUpgradeError := upgradeError.Error()
testCases := []struct {
msg string
upgradeError *types.UpgradeError
expError error
}{
{
msg: "no underlying error",
upgradeError: types.NewUpgradeError(1, nil),
expError: nil,
},
{
msg: "underlying error",
upgradeError: types.NewUpgradeError(1, types.ErrInvalidUpgrade),
expError: types.ErrInvalidUpgrade,
},
}

suite.Require().Equal(types.ErrInvalidChannel, unWrapped, "unwrapped error was not equal to base underlying error")
suite.Require().Equal(originalUpgradeError, postUnwrapUpgradeError, "original error was modified when unwrapped")
for _, tc := range testCases {
tc := tc
suite.Run(tc.msg, func() {
upgradeError := tc.upgradeError
err := upgradeError.Unwrap()
suite.Require().Equal(tc.expError, err)
})
}
}

func (suite *TypesTestSuite) TestIsUpgradeError() {
Expand Down

0 comments on commit 1718cbe

Please sign in to comment.