Skip to content

Commit

Permalink
[CT-1196] move replay protectin out of sigverify (#2257)
Browse files Browse the repository at this point in the history
  • Loading branch information
jayy04 authored Sep 20, 2024
1 parent 63802c6 commit 900984e
Show file tree
Hide file tree
Showing 4 changed files with 141 additions and 56 deletions.
12 changes: 11 additions & 1 deletion protocol/app/ante.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,9 +114,12 @@ func NewAnteHandler(options HandlerOptions) (sdk.AnteHandler, error) {
validateBasic: ante.NewValidateBasicDecorator(),
validateSigCount: ante.NewValidateSigCountDecorator(options.AccountKeeper),
incrementSequence: ante.NewIncrementSequenceDecorator(options.AccountKeeper),
sigVerification: customante.NewSigVerificationDecorator(
replayProtection: customante.NewReplayProtectionDecorator(
options.AccountKeeper,
*options.AccountplusKeeper,
),
sigVerification: customante.NewSigVerificationDecorator(
options.AccountKeeper,
options.SignModeHandler,
),
consumeTxSizeGas: ante.NewConsumeGasForTxSizeDecorator(options.AccountKeeper),
Expand Down Expand Up @@ -155,6 +158,7 @@ type lockingAnteHandler struct {
validateBasic ante.ValidateBasicDecorator
validateSigCount ante.ValidateSigCountDecorator
incrementSequence ante.IncrementSequenceDecorator
replayProtection customante.ReplayProtectionDecorator
sigVerification customante.SigVerificationDecorator
consumeTxSizeGas ante.ConsumeTxSizeGasDecorator
deductFee ante.DeductFeeDecorator
Expand Down Expand Up @@ -244,6 +248,9 @@ func (h *lockingAnteHandler) clobAnteHandle(ctx sdk.Context, tx sdk.Tx, simulate
if ctx, err = h.sigGasConsume.AnteHandle(ctx, tx, simulate, noOpAnteHandle); err != nil {
return ctx, err
}
if ctx, err = h.replayProtection.AnteHandle(ctx, tx, simulate, noOpAnteHandle); err != nil {
return ctx, err
}
if ctx, err = h.sigVerification.AnteHandle(ctx, tx, simulate, noOpAnteHandle); err != nil {
return ctx, err
}
Expand Down Expand Up @@ -411,6 +418,9 @@ func (h *lockingAnteHandler) otherMsgAnteHandle(ctx sdk.Context, tx sdk.Tx, simu
if ctx, err = h.sigGasConsume.AnteHandle(ctx, tx, simulate, noOpAnteHandle); err != nil {
return ctx, err
}
if ctx, err = h.replayProtection.AnteHandle(ctx, tx, simulate, noOpAnteHandle); err != nil {
return ctx, err
}
if ctx, err = h.sigVerification.AnteHandle(ctx, tx, simulate, noOpAnteHandle); err != nil {
return ctx, err
}
Expand Down
120 changes: 120 additions & 0 deletions protocol/app/ante/replay_protection.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
package ante

import (
"fmt"

errorsmod "cosmossdk.io/errors"
"github.com/cosmos/cosmos-sdk/telemetry"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
sdkante "github.com/cosmos/cosmos-sdk/x/auth/ante"
authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing"
"github.com/dydxprotocol/v4-chain/protocol/lib/metrics"
accountpluskeeper "github.com/dydxprotocol/v4-chain/protocol/x/accountplus/keeper"
gometrics "github.com/hashicorp/go-metrics"
)

type ReplayProtectionDecorator struct {
ak sdkante.AccountKeeper
akp accountpluskeeper.Keeper
}

func NewReplayProtectionDecorator(
ak sdkante.AccountKeeper,
akp accountpluskeeper.Keeper,
) ReplayProtectionDecorator {
return ReplayProtectionDecorator{
ak: ak,
akp: akp,
}
}

func (rpd ReplayProtectionDecorator) AnteHandle(
ctx sdk.Context,
tx sdk.Tx,
simulate bool,
next sdk.AnteHandler,
) (newCtx sdk.Context, err error) {
sigTx, ok := tx.(authsigning.Tx)
if !ok {
return ctx, errorsmod.Wrap(sdkerrors.ErrTxDecode, "invalid transaction type")
}

// stdSigs contains the sequence number, account number, and signatures.
// When simulating, this would just be a 0-length slice.
sigs, err := sigTx.GetSignaturesV2()
if err != nil {
return ctx, err
}

signers, err := sigTx.GetSigners()
if err != nil {
return ctx, err
}

// Check that signer length and signature length are the same.
// The ordering of the sigs and signers have matching ordering (sigs[i] belongs to signers[i]).
if len(sigs) != len(signers) {
err := errorsmod.Wrapf(
sdkerrors.ErrUnauthorized,
"invalid number of signer; expected: %d, got %d",
len(signers),
len(sigs),
)
return ctx, err
}

// Sequence number validation can be skipped if the given transaction consists of
// only messages that use `GoodTilBlock` for replay protection.
skipSequenceValidation := ShouldSkipSequenceValidation(tx.GetMsgs())

if !skipSequenceValidation {
// Iterate on sig and signer pairs.
for i, sig := range sigs {
acc, err := sdkante.GetSignerAcc(ctx, rpd.ak, signers[i])
if err != nil {
return ctx, err
}

// Check account sequence number.
// Skip individual sequence number validation since this transaction use
// `GoodTilBlock` for replay protection.
if accountpluskeeper.IsTimestampNonce(sig.Sequence) {
if err := rpd.akp.ProcessTimestampNonce(ctx, acc, sig.Sequence); err != nil {
telemetry.IncrCounterWithLabels(
[]string{metrics.TimestampNonce, metrics.Invalid, metrics.Count},
1,
[]gometrics.Label{metrics.GetLabelForIntValue(metrics.ExecMode, int(ctx.ExecMode()))},
)
return ctx, errorsmod.Wrapf(sdkerrors.ErrWrongSequence, err.Error())
}
telemetry.IncrCounterWithLabels(
[]string{metrics.TimestampNonce, metrics.Valid, metrics.Count},
1,
[]gometrics.Label{metrics.GetLabelForIntValue(metrics.ExecMode, int(ctx.ExecMode()))},
)
} else {
if sig.Sequence != acc.GetSequence() {
labels := make([]gometrics.Label, 0)
if len(tx.GetMsgs()) > 0 {
labels = append(
labels,
metrics.GetLabelForStringValue(metrics.MessageType, fmt.Sprintf("%T", tx.GetMsgs()[0])),
)
}
telemetry.IncrCounterWithLabels(
[]string{metrics.SequenceNumber, metrics.Invalid, metrics.Count},
1,
labels,
)
return ctx, errorsmod.Wrapf(
sdkerrors.ErrWrongSequence,
"account sequence mismatch, expected %d, got %d", acc.GetSequence(), sig.Sequence,
)
}
}
}
}

return next(ctx, tx, simulate)
}
51 changes: 0 additions & 51 deletions protocol/app/ante/sigverify.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,10 @@ import (
errorsmod "cosmossdk.io/errors"
txsigning "cosmossdk.io/x/tx/signing"
codectypes "github.com/cosmos/cosmos-sdk/codec/types"
"github.com/cosmos/cosmos-sdk/telemetry"
sdk "github.com/cosmos/cosmos-sdk/types"
sdkerrors "github.com/cosmos/cosmos-sdk/types/errors"
sdkante "github.com/cosmos/cosmos-sdk/x/auth/ante"
authsigning "github.com/cosmos/cosmos-sdk/x/auth/signing"
"github.com/dydxprotocol/v4-chain/protocol/lib/metrics"
accountpluskeeper "github.com/dydxprotocol/v4-chain/protocol/x/accountplus/keeper"
gometrics "github.com/hashicorp/go-metrics"
"google.golang.org/protobuf/types/known/anypb"
)

Expand All @@ -24,18 +20,15 @@ import (
// CONTRACT: Tx must implement SigVerifiableTx interface
type SigVerificationDecorator struct {
ak sdkante.AccountKeeper
akp accountpluskeeper.Keeper
signModeHandler *txsigning.HandlerMap
}

func NewSigVerificationDecorator(
ak sdkante.AccountKeeper,
akp accountpluskeeper.Keeper,
signModeHandler *txsigning.HandlerMap,
) SigVerificationDecorator {
return SigVerificationDecorator{
ak: ak,
akp: akp,
signModeHandler: signModeHandler,
}
}
Expand Down Expand Up @@ -75,10 +68,6 @@ func (svd SigVerificationDecorator) AnteHandle(
return ctx, err
}

// Sequence number validation can be skipped if the given transaction consists of
// only messages that use `GoodTilBlock` for replay protection.
skipSequenceValidation := ShouldSkipSequenceValidation(tx.GetMsgs())

// Iterate on sig and signer pairs.
for i, sig := range sigs {
acc, err := sdkante.GetSignerAcc(ctx, svd.ak, signers[i])
Expand All @@ -92,46 +81,6 @@ func (svd SigVerificationDecorator) AnteHandle(
return ctx, errorsmod.Wrap(sdkerrors.ErrInvalidPubKey, "pubkey on account is not set")
}

// Check account sequence number.
// Skip individual sequence number validation since this transaction use
// `GoodTilBlock` for replay protection.
if !skipSequenceValidation {
if accountpluskeeper.IsTimestampNonce(sig.Sequence) {
if err := svd.akp.ProcessTimestampNonce(ctx, acc, sig.Sequence); err != nil {
telemetry.IncrCounterWithLabels(
[]string{metrics.TimestampNonce, metrics.Invalid, metrics.Count},
1,
[]gometrics.Label{metrics.GetLabelForIntValue(metrics.ExecMode, int(ctx.ExecMode()))},
)
return ctx, errorsmod.Wrapf(sdkerrors.ErrWrongSequence, err.Error())
}
telemetry.IncrCounterWithLabels(
[]string{metrics.TimestampNonce, metrics.Valid, metrics.Count},
1,
[]gometrics.Label{metrics.GetLabelForIntValue(metrics.ExecMode, int(ctx.ExecMode()))},
)
} else {
if sig.Sequence != acc.GetSequence() {
labels := make([]gometrics.Label, 0)
if len(tx.GetMsgs()) > 0 {
labels = append(
labels,
metrics.GetLabelForStringValue(metrics.MessageType, fmt.Sprintf("%T", tx.GetMsgs()[0])),
)
}
telemetry.IncrCounterWithLabels(
[]string{metrics.SequenceNumber, metrics.Invalid, metrics.Count},
1,
labels,
)
return ctx, errorsmod.Wrapf(
sdkerrors.ErrWrongSequence,
"account sequence mismatch, expected %d, got %d", acc.GetSequence(), sig.Sequence,
)
}
}
}

// retrieve signer data
genesis := ctx.BlockHeight() == 0
chainID := ctx.ChainID()
Expand Down
14 changes: 10 additions & 4 deletions protocol/app/ante/sigverify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,12 +104,15 @@ func TestSigVerification(t *testing.T) {
txConfigOpts,
)
require.NoError(t, err)
svd := customante.NewSigVerificationDecorator(
rpd := customante.NewReplayProtectionDecorator(
suite.AccountKeeper,
suite.AccountplusKeeper,
)
svd := customante.NewSigVerificationDecorator(
suite.AccountKeeper,
anteTxConfig.SignModeHandler(),
)
antehandler := sdk.ChainAnteDecorators(spkd, svd)
antehandler := sdk.ChainAnteDecorators(spkd, rpd, svd)
defaultSignMode, err := authsign.APISignModeToInternal(anteTxConfig.SignModeHandler().DefaultMode())
require.NoError(t, err)

Expand Down Expand Up @@ -468,12 +471,15 @@ func runSigDecorators(t *testing.T, params types.Params, _ bool, privs ...crypto

spkd := sdkante.NewSetPubKeyDecorator(suite.AccountKeeper)
svgc := sdkante.NewSigGasConsumeDecorator(suite.AccountKeeper, sdkante.DefaultSigVerificationGasConsumer)
svd := customante.NewSigVerificationDecorator(
rpd := customante.NewReplayProtectionDecorator(
suite.AccountKeeper,
suite.AccountplusKeeper,
)
svd := customante.NewSigVerificationDecorator(
suite.AccountKeeper,
suite.ClientCtx.TxConfig.SignModeHandler(),
)
antehandler := sdk.ChainAnteDecorators(spkd, svgc, svd)
antehandler := sdk.ChainAnteDecorators(spkd, svgc, rpd, svd)

txBytes, err := suite.ClientCtx.TxConfig.TxEncoder()(tx)
require.NoError(t, err)
Expand Down

0 comments on commit 900984e

Please sign in to comment.