diff --git a/pkg/user/pruning_test.go b/pkg/user/pruning_test.go new file mode 100644 index 0000000000..14465d7e64 --- /dev/null +++ b/pkg/user/pruning_test.go @@ -0,0 +1,51 @@ +package user + +import ( + "fmt" + "testing" + "time" + + "github.com/stretchr/testify/require" +) + +func TestPruningInTxTracker(t *testing.T) { + txClient := &TxClient{ + txTracker: make(map[string]txInfo), + } + numTransactions := 10 + + // Add 10 transactions to the tracker that are 10 minutes old + var txsToBePruned int + var txsNotReadyToBePruned int + for i := 0; i < numTransactions; i++ { + // 5 transactions will be pruned + if i%2 == 0 { + txClient.txTracker["tx"+fmt.Sprint(i)] = txInfo{ + signer: "signer" + fmt.Sprint(i), + sequence: uint64(i), + timestamp: time.Now(). + Add(-10 * time.Minute), + } + txsToBePruned++ + } else { + txClient.txTracker["tx"+fmt.Sprint(i)] = txInfo{ + signer: "signer" + fmt.Sprint(i), + sequence: uint64(i), + timestamp: time.Now(). + Add(-5 * time.Minute), + } + txsNotReadyToBePruned++ + } + } + + txTrackerBeforePruning := len(txClient.txTracker) + + // All transactions were indexed + require.Equal(t, numTransactions, len(txClient.txTracker)) + txClient.pruneTxTracker() + // Prunes the transactions that are 10 minutes old + // 5 transactions will be pruned + require.Equal(t, txTrackerBeforePruning-txsToBePruned, txsToBePruned) + // 5 transactions will not be pruned + require.Equal(t, len(txClient.txTracker), txsNotReadyToBePruned) +} diff --git a/pkg/user/tx_client.go b/pkg/user/tx_client.go index a403a91991..ce2f7b8933 100644 --- a/pkg/user/tx_client.go +++ b/pkg/user/tx_client.go @@ -12,7 +12,6 @@ import ( "time" "github.com/celestiaorg/go-square/v2/share" - blobtx "github.com/celestiaorg/go-square/v2/tx" "github.com/cosmos/cosmos-sdk/client" nodeservice "github.com/cosmos/cosmos-sdk/client/grpc/node" "github.com/cosmos/cosmos-sdk/client/grpc/tmservice" @@ -27,7 +26,6 @@ import ( "github.com/celestiaorg/celestia-app/v3/app" "github.com/celestiaorg/celestia-app/v3/app/encoding" - apperrors "github.com/celestiaorg/celestia-app/v3/app/errors" "github.com/celestiaorg/celestia-app/v3/app/grpc/tx" "github.com/celestiaorg/celestia-app/v3/pkg/appconsts" "github.com/celestiaorg/celestia-app/v3/x/blob/types" @@ -35,12 +33,21 @@ import ( ) const ( - DefaultPollTime = 3 * time.Second - DefaultGasMultiplier float64 = 1.1 + DefaultPollTime = 3 * time.Second + DefaultGasMultiplier float64 = 1.1 + txTrackerPruningInterval = 10 * time.Minute ) type Option func(client *TxClient) +// txInfo is a struct that holds the sequence and the signer of a transaction +// in the local tx pool. +type txInfo struct { + sequence uint64 + signer string + timestamp time.Time +} + // TxResponse is a response from the chain after // a transaction has been submitted. type TxResponse struct { @@ -137,6 +144,9 @@ type TxClient struct { defaultGasPrice float64 defaultAccount string defaultAddress sdktypes.AccAddress + // txTracker maps the tx hash to the Sequence and signer of the transaction + // that was submitted to the chain + txTracker map[string]txInfo } // NewTxClient returns a new signer using the provided keyring @@ -169,6 +179,7 @@ func NewTxClient( defaultGasPrice: appconsts.DefaultMinGasPrice, defaultAccount: records[0].Name, defaultAddress: addr, + txTracker: make(map[string]txInfo), } for _, opt := range options { @@ -302,6 +313,12 @@ func (client *TxClient) SubmitTx(ctx context.Context, msgs []sdktypes.Msg, opts func (client *TxClient) BroadcastTx(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOption) (*sdktypes.TxResponse, error) { client.mtx.Lock() defer client.mtx.Unlock() + + // prune transactions that are older than 10 minutes + // pruning has to be done in broadcast, since users + // might not always call ConfirmTx(). + client.pruneTxTracker() + account, err := client.getAccountNameFromMsgs(msgs) if err != nil { return nil, err @@ -368,23 +385,20 @@ func (client *TxClient) broadcastTx(ctx context.Context, txBytes []byte, signer return nil, err } if resp.TxResponse.Code != abci.CodeTypeOK { - if apperrors.IsNonceMismatchCode(resp.TxResponse.Code) { - // query the account to update the sequence number on-chain for the account - _, seqNum, err := QueryAccount(ctx, client.grpc, client.registry, client.signer.accounts[signer].address) - if err != nil { - return nil, fmt.Errorf("querying account for new sequence number: %w\noriginal tx response: %s", err, resp.TxResponse.RawLog) - } - if err := client.signer.SetSequence(signer, seqNum); err != nil { - return nil, fmt.Errorf("setting sequence: %w", err) - } - return client.retryBroadcastingTx(ctx, txBytes) - } broadcastTxErr := &BroadcastTxError{ TxHash: resp.TxResponse.TxHash, Code: resp.TxResponse.Code, ErrorLog: resp.TxResponse.RawLog, } - return resp.TxResponse, broadcastTxErr + return nil, broadcastTxErr + } + + // save the sequence and signer of the transaction in the local txTracker + // before the sequence is incremented + client.txTracker[resp.TxResponse.TxHash] = txInfo{ + sequence: client.signer.accounts[signer].Sequence(), + signer: signer, + timestamp: time.Now(), } // after the transaction has been submitted, we can increment the @@ -395,62 +409,13 @@ func (client *TxClient) broadcastTx(ctx context.Context, txBytes []byte, signer return resp.TxResponse, nil } -// retryBroadcastingTx creates a new transaction by copying over an existing transaction but creates a new signature with the -// new sequence number. It then calls `broadcastTx` and attempts to submit the transaction -func (client *TxClient) retryBroadcastingTx(ctx context.Context, txBytes []byte) (*sdktypes.TxResponse, error) { - blobTx, isBlobTx, err := blobtx.UnmarshalBlobTx(txBytes) - if isBlobTx { - // only check the error if the bytes are supposed to be of type blob tx - if err != nil { - return nil, err +// pruneTxTracker removes transactions from the local tx tracker that are older than 10 minutes +func (client *TxClient) pruneTxTracker() { + for hash, txInfo := range client.txTracker { + if time.Since(txInfo.timestamp) >= txTrackerPruningInterval { + delete(client.txTracker, hash) } - txBytes = blobTx.Tx - } - tx, err := client.signer.DecodeTx(txBytes) - if err != nil { - return nil, err } - - opts := make([]TxOption, 0) - if granter := tx.FeeGranter(); granter != nil { - opts = append(opts, SetFeeGranter(granter)) - } - if payer := tx.FeePayer(); payer != nil { - opts = append(opts, SetFeePayer(payer)) - } - if memo := tx.GetMemo(); memo != "" { - opts = append(opts, SetMemo(memo)) - } - if fee := tx.GetFee(); fee != nil { - opts = append(opts, SetFee(fee.AmountOf(appconsts.BondDenom).Uint64())) - } - if gas := tx.GetGas(); gas > 0 { - opts = append(opts, SetGasLimit(gas)) - } - - txBuilder, err := client.signer.txBuilder(tx.GetMsgs(), opts...) - if err != nil { - return nil, err - } - signer, _, err := client.signer.signTransaction(txBuilder) - if err != nil { - return nil, fmt.Errorf("resigning transaction: %w", err) - } - - newTxBytes, err := client.signer.EncodeTx(txBuilder.GetTx()) - if err != nil { - return nil, err - } - - // rewrap the blob tx if it was originally a blob tx - if isBlobTx { - newTxBytes, err = blobtx.MarshalBlobTx(newTxBytes, blobTx.Blobs...) - if err != nil { - return nil, err - } - } - - return client.broadcastTx(ctx, newTxBytes, signer) } // ConfirmTx periodically pings the provided node for the commitment of a transaction by its @@ -468,40 +433,68 @@ func (client *TxClient) ConfirmTx(ctx context.Context, txHash string) (*TxRespon return nil, err } - if resp != nil { - switch resp.Status { - case core.TxStatusPending: - // Continue polling if the transaction is still pending - select { - case <-ctx.Done(): - return nil, ctx.Err() - case <-pollTicker.C: - continue - } - case core.TxStatusCommitted: - txResponse := &TxResponse{ - Height: resp.Height, - TxHash: txHash, - Code: resp.ExecutionCode, - } - if resp.ExecutionCode != abci.CodeTypeOK { - executionErr := &ExecutionError{ - TxHash: txHash, - Code: resp.ExecutionCode, - ErrorLog: resp.Error, - } - return nil, executionErr + switch resp.Status { + case core.TxStatusPending: + // Continue polling if the transaction is still pending + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-pollTicker.C: + continue + } + case core.TxStatusCommitted: + txResponse := &TxResponse{ + Height: resp.Height, + TxHash: txHash, + Code: resp.ExecutionCode, + } + if resp.ExecutionCode != abci.CodeTypeOK { + executionErr := &ExecutionError{ + TxHash: txHash, + Code: resp.ExecutionCode, + ErrorLog: resp.Error, } - return txResponse, nil - case core.TxStatusEvicted: - return nil, fmt.Errorf("tx was evicted from the mempool") - default: - return nil, fmt.Errorf("unknown tx: %s", txHash) + client.deleteFromTxTracker(txHash) + return nil, executionErr } + client.deleteFromTxTracker(txHash) + return txResponse, nil + case core.TxStatusEvicted: + return nil, client.handleEvictions(txHash) + default: + client.deleteFromTxTracker(txHash) + return nil, fmt.Errorf("transaction with hash %s not found; it was likely rejected", txHash) } } } +// handleEvictions handles the scenario where a transaction is evicted from the mempool. +// It removes the evicted transaction from the local tx tracker without incrementing +// the signer's sequence. +func (client *TxClient) handleEvictions(txHash string) error { + client.mtx.Lock() + defer client.mtx.Unlock() + // Get transaction from the local tx tracker + txInfo, exists := client.txTracker[txHash] + if !exists { + return fmt.Errorf("tx: %s not found in tx client txTracker; likely failed during broadcast", txHash) + } + // The sequence should be rolled back to the sequence of the transaction that was evicted to be + // ready for resubmission. All transactions with a later nonce will be kicked by the nodes tx pool. + if err := client.signer.SetSequence(txInfo.signer, txInfo.sequence); err != nil { + return fmt.Errorf("setting sequence: %w", err) + } + delete(client.txTracker, txHash) + return fmt.Errorf("tx was evicted from the mempool") +} + +// deleteFromTxTracker safely deletes a transaction from the local tx tracker. +func (client *TxClient) deleteFromTxTracker(txHash string) { + client.mtx.Lock() + defer client.mtx.Unlock() + delete(client.txTracker, txHash) +} + // EstimateGas simulates the transaction, calculating the amount of gas that was consumed during execution. The final // result will be multiplied by gasMultiplier(that is set in TxClient) func (client *TxClient) EstimateGas(ctx context.Context, msgs []sdktypes.Msg, opts ...TxOption) (uint64, error) { @@ -576,6 +569,7 @@ func (client *TxClient) checkAccountLoaded(ctx context.Context, account string) if err != nil { return fmt.Errorf("retrieving address from keyring: %w", err) } + // FIXME: have a less trusting way of getting the account number and sequence accNum, sequence, err := QueryAccount(ctx, client.grpc, client.registry, addr) if err != nil { return fmt.Errorf("querying account %s: %w", account, err) @@ -604,6 +598,14 @@ func (client *TxClient) getAccountNameFromMsgs(msgs []sdktypes.Msg) (string, err return record.Name, nil } +// GetTxFromTxTracker gets transaction info from the tx client's local tx tracker by its hash +func (client *TxClient) GetTxFromTxTracker(hash string) (sequence uint64, signer string, exists bool) { + client.mtx.Lock() + defer client.mtx.Unlock() + txInfo, exists := client.txTracker[hash] + return txInfo.sequence, txInfo.signer, exists +} + // Signer exposes the tx clients underlying signer func (client *TxClient) Signer() *Signer { return client.signer diff --git a/pkg/user/tx_client_test.go b/pkg/user/tx_client_test.go index 1588c3a060..6ae0c2efa2 100644 --- a/pkg/user/tx_client_test.go +++ b/pkg/user/tx_client_test.go @@ -5,21 +5,21 @@ import ( "testing" "time" + "github.com/celestiaorg/celestia-app/v3/app" + "github.com/celestiaorg/celestia-app/v3/app/encoding" + "github.com/celestiaorg/celestia-app/v3/pkg/appconsts" + "github.com/celestiaorg/celestia-app/v3/pkg/user" + "github.com/celestiaorg/celestia-app/v3/test/util/blobfactory" + "github.com/celestiaorg/celestia-app/v3/test/util/testnode" sdk "github.com/cosmos/cosmos-sdk/types" sdktx "github.com/cosmos/cosmos-sdk/types/tx" + "github.com/cosmos/cosmos-sdk/x/authz" bank "github.com/cosmos/cosmos-sdk/x/bank/types" "github.com/stretchr/testify/require" "github.com/stretchr/testify/suite" abci "github.com/tendermint/tendermint/abci/types" "github.com/tendermint/tendermint/libs/rand" - - "github.com/celestiaorg/celestia-app/v3/app" - "github.com/celestiaorg/celestia-app/v3/app/encoding" - "github.com/celestiaorg/celestia-app/v3/pkg/appconsts" - "github.com/celestiaorg/celestia-app/v3/pkg/user" - "github.com/celestiaorg/celestia-app/v3/test/util/blobfactory" - "github.com/celestiaorg/celestia-app/v3/test/util/testnode" ) func TestTxClientTestSuite(t *testing.T) { @@ -39,15 +39,7 @@ type TxClientTestSuite struct { } func (suite *TxClientTestSuite) SetupSuite() { - suite.encCfg = encoding.MakeConfig(app.ModuleEncodingRegisters...) - config := testnode.DefaultConfig(). - WithFundedAccounts("a", "b", "c"). - WithAppCreator(testnode.CustomAppCreator("0utia")) - suite.ctx, _, _ = testnode.NewNetwork(suite.T(), config) - _, err := suite.ctx.WaitForHeight(1) - suite.Require().NoError(err) - suite.txClient, err = user.SetupTxClient(suite.ctx.GoContext(), suite.ctx.Keyring, suite.ctx.GRPCClient, suite.encCfg, user.WithGasMultiplier(1.2)) - suite.Require().NoError(err) + suite.encCfg, suite.txClient, suite.ctx = setupTxClient(suite.T(), testnode.DefaultTendermintConfig().Mempool.TTLDuration) suite.serviceClient = sdktx.NewServiceClient(suite.ctx.GRPCClient) } @@ -162,9 +154,11 @@ func (suite *TxClientTestSuite) TestConfirmTx() { ctx, cancel := context.WithTimeout(suite.ctx.GoContext(), time.Second) defer cancel() + seqBeforeBroadcast := suite.txClient.Signer().Account(suite.txClient.DefaultAccountName()).Sequence() msg := bank.NewMsgSend(suite.txClient.DefaultAddress(), testnode.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 10))) resp, err := suite.txClient.BroadcastTx(ctx, []sdk.Msg{msg}) require.NoError(t, err) + assertTxInTxTracker(t, suite.txClient, resp.TxHash, suite.txClient.DefaultAccountName(), seqBeforeBroadcast) _, err = suite.txClient.ConfirmTx(ctx, resp.TxHash) require.Error(t, err) @@ -174,48 +168,90 @@ func (suite *TxClientTestSuite) TestConfirmTx() { t.Run("should error when tx is not found", func(t *testing.T) { ctx, cancel := context.WithTimeout(suite.ctx.GoContext(), 5*time.Second) defer cancel() - _, err := suite.txClient.ConfirmTx(ctx, "E32BD15CAF57AF15D17B0D63CF4E63A9835DD1CEBB059C335C79586BC3013728") - require.Contains(t, err.Error(), "unknown tx: E32BD15CAF57AF15D17B0D63CF4E63A9835DD1CEBB059C335C79586BC3013728") + resp, err := suite.txClient.ConfirmTx(ctx, "E32BD15CAF57AF15D17B0D63CF4E63A9835DD1CEBB059C335C79586BC3013728") + require.Contains(t, err.Error(), "transaction with hash E32BD15CAF57AF15D17B0D63CF4E63A9835DD1CEBB059C335C79586BC3013728 not found; it was likely rejected") + require.Nil(t, resp) }) t.Run("should return error log when execution fails", func(t *testing.T) { + seqBeforeBroadcast := suite.txClient.Signer().Account(suite.txClient.DefaultAccountName()).Sequence() innerMsg := bank.NewMsgSend(testnode.RandomAddress().(sdk.AccAddress), testnode.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 10))) msg := authz.NewMsgExec(suite.txClient.DefaultAddress(), []sdk.Msg{innerMsg}) resp, err := suite.txClient.BroadcastTx(suite.ctx.GoContext(), []sdk.Msg{&msg}, fee, gas) require.NoError(t, err) - _, err = suite.txClient.ConfirmTx(suite.ctx.GoContext(), resp.TxHash) + assertTxInTxTracker(t, suite.txClient, resp.TxHash, suite.txClient.DefaultAccountName(), seqBeforeBroadcast) + + confirmTxResp, err := suite.txClient.ConfirmTx(suite.ctx.GoContext(), resp.TxHash) require.Error(t, err) require.Contains(t, err.Error(), "authorization not found") + require.Nil(t, confirmTxResp) + require.True(t, wasRemovedFromTxTracker(resp.TxHash, suite.txClient)) }) t.Run("should success when tx is found immediately", func(t *testing.T) { addr := suite.txClient.DefaultAddress() + seqBeforeBroadcast := suite.txClient.Signer().Account(suite.txClient.DefaultAccountName()).Sequence() msg := bank.NewMsgSend(addr, testnode.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 10))) resp, err := suite.txClient.BroadcastTx(suite.ctx.GoContext(), []sdk.Msg{msg}, fee, gas) require.NoError(t, err) - require.NotNil(t, resp) + require.Equal(t, resp.Code, abci.CodeTypeOK) + assertTxInTxTracker(t, suite.txClient, resp.TxHash, suite.txClient.DefaultAccountName(), seqBeforeBroadcast) + ctx, cancel := context.WithTimeout(suite.ctx.GoContext(), 30*time.Second) defer cancel() confirmTxResp, err := suite.txClient.ConfirmTx(ctx, resp.TxHash) require.NoError(t, err) require.Equal(t, abci.CodeTypeOK, confirmTxResp.Code) + require.True(t, wasRemovedFromTxTracker(resp.TxHash, suite.txClient)) }) t.Run("should error when tx is found with a non-zero error code", func(t *testing.T) { balance := suite.queryCurrentBalance(t) addr := suite.txClient.DefaultAddress() + seqBeforeBroadcast := suite.txClient.Signer().Account(suite.txClient.DefaultAccountName()).Sequence() // Create a msg send with out of balance, ensure this tx fails msg := bank.NewMsgSend(addr, testnode.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 1+balance))) resp, err := suite.txClient.BroadcastTx(suite.ctx.GoContext(), []sdk.Msg{msg}, fee, gas) require.NoError(t, err) - require.NotNil(t, resp) - _, err = suite.txClient.ConfirmTx(suite.ctx.GoContext(), resp.TxHash) + require.Equal(t, resp.Code, abci.CodeTypeOK) + assertTxInTxTracker(t, suite.txClient, resp.TxHash, suite.txClient.DefaultAccountName(), seqBeforeBroadcast) + + confirmTxResp, err := suite.txClient.ConfirmTx(suite.ctx.GoContext(), resp.TxHash) require.Error(t, err) + require.Nil(t, confirmTxResp) code := err.(*user.ExecutionError).Code require.NotEqual(t, abci.CodeTypeOK, code) + require.True(t, wasRemovedFromTxTracker(resp.TxHash, suite.txClient)) }) } +func TestEvictions(t *testing.T) { + _, txClient, ctx := setupTxClient(t, 1*time.Nanosecond) + + fee := user.SetFee(1e6) + gas := user.SetGasLimit(1e6) + + // Keep submitting the transaction until we get the eviction error + sender := txClient.Signer().Account(txClient.DefaultAccountName()) + msg := bank.NewMsgSend(sender.Address(), testnode.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 10))) + var seqBeforeEviction uint64 + // Loop five times until the tx is evicted + for i := 0; i < 5; i++ { + seqBeforeEviction = sender.Sequence() + resp, err := txClient.BroadcastTx(ctx.GoContext(), []sdk.Msg{msg}, fee, gas) + require.NoError(t, err) + _, err = txClient.ConfirmTx(ctx.GoContext(), resp.TxHash) + if err != nil { + if err.Error() == "tx was evicted from the mempool" { + break + } + } + } + + seqAfterEviction := sender.Sequence() + require.Equal(t, seqBeforeEviction, seqAfterEviction) +} + func (suite *TxClientTestSuite) TestGasEstimation() { addr := suite.txClient.DefaultAddress() msg := bank.NewMsgSend(addr, testnode.RandomAddress().(sdk.AccAddress), sdk.NewCoins(sdk.NewInt64Coin(app.BondDenom, 10))) @@ -281,3 +317,36 @@ func (suite *TxClientTestSuite) queryCurrentBalance(t *testing.T) int64 { require.NoError(t, err) return balanceResp.Balances.AmountOf(app.BondDenom).Int64() } + +func wasRemovedFromTxTracker(txHash string, txClient *user.TxClient) bool { + seq, signer, exists := txClient.GetTxFromTxTracker(txHash) + return !exists && seq == 0 && signer == "" +} + +// asserts that a tx was indexed in the tx tracker and that the sequence does not increase +func assertTxInTxTracker(t *testing.T, txClient *user.TxClient, txHash string, expectedSigner string, seqBeforeBroadcast uint64) { + seqFromTxTracker, signer, exists := txClient.GetTxFromTxTracker(txHash) + require.True(t, exists) + require.Equal(t, expectedSigner, signer) + seqAfterBroadcast := txClient.Signer().Account(expectedSigner).Sequence() + // TxInfo is indexed before the nonce is increased + require.Equal(t, seqBeforeBroadcast, seqFromTxTracker) + // Successfully broadcast transaction increases the sequence + require.Equal(t, seqAfterBroadcast, seqBeforeBroadcast+1) +} + +func setupTxClient(t *testing.T, ttlDuration time.Duration) (encoding.Config, *user.TxClient, testnode.Context) { + encCfg := encoding.MakeConfig(app.ModuleEncodingRegisters...) + defaultTmConfig := testnode.DefaultTendermintConfig() + defaultTmConfig.Mempool.TTLDuration = ttlDuration + testnodeConfig := testnode.DefaultConfig(). + WithTendermintConfig(defaultTmConfig). + WithFundedAccounts("a", "b", "c"). + WithAppCreator(testnode.CustomAppCreator("0utia")) + ctx, _, _ := testnode.NewNetwork(t, testnodeConfig) + _, err := ctx.WaitForHeight(1) + require.NoError(t, err) + txClient, err := user.SetupTxClient(ctx.GoContext(), ctx.Keyring, ctx.GRPCClient, encCfg, user.WithGasMultiplier(1.2)) + require.NoError(t, err) + return encCfg, txClient, ctx +}