From 0d201dead39e690cd0fd4ad82a1ea0cbe9ee5025 Mon Sep 17 00:00:00 2001 From: yihuang Date: Wed, 4 Sep 2024 18:42:43 +0800 Subject: [PATCH 1/4] fix(mempool): data race in mempool prepare proposal handler (#21413) --- CHANGELOG.md | 3 + baseapp/abci_utils.go | 37 +++++++----- types/mempool/mempool.go | 6 +- types/mempool/noop.go | 9 +-- types/mempool/priority_nonce.go | 17 +++++- types/mempool/priority_nonce_test.go | 85 ++++++++++++++++++++++++++++ types/mempool/sender_nonce.go | 17 +++++- 7 files changed, 153 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index add0a5b3b1b8..757f475816f8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -51,9 +51,12 @@ Every module contains its own CHANGELOG.md. Please refer to the module you are i ### Bug Fixes * (baseapp) [#21256](https://github.com/cosmos/cosmos-sdk/pull/21256) Halt height will not commit the block indicated, meaning that if halt-height is set to 10, only blocks until 9 (included) will be committed. This is to go back to the original behavior before a change was introduced in v0.50.0. +* (baseapp) [#21413](https://github.com/cosmos/cosmos-sdk/pull/21413) Fix data race in sdk mempool. ### API Breaking Changes +* (baseapp) [#21413](https://github.com/cosmos/cosmos-sdk/pull/21413) Add `SelectBy` method to `Mempool` interface, which is thread-safe to use. + ### Deprecated * (types) [#21435](https://github.com/cosmos/cosmos-sdk/pull/21435) The `String()` method on `AccAddress`, `ValAddress` and `ConsAddress` have been deprecated. This is done because those are still using the deprecated global `sdk.Config`. Use an `address.Codec` instead. diff --git a/baseapp/abci_utils.go b/baseapp/abci_utils.go index da6adef5539d..4fa068b3c9a4 100644 --- a/baseapp/abci_utils.go +++ b/baseapp/abci_utils.go @@ -285,14 +285,18 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan return &abci.PrepareProposalResponse{Txs: h.txSelector.SelectedTxs(ctx)}, nil } - iterator := h.mempool.Select(ctx, req.Txs) selectedTxsSignersSeqs := make(map[string]uint64) - var selectedTxsNums int - for iterator != nil { - memTx := iterator.Tx() + var ( + resError error + selectedTxsNums int + invalidTxs []sdk.Tx // invalid txs to be removed out of the loop to avoid dead lock + ) + h.mempool.SelectBy(ctx, req.Txs, func(memTx sdk.Tx) bool { signerData, err := h.signerExtAdapter.GetSigners(memTx) if err != nil { - return nil, err + // propagate the error to the caller + resError = err + return false } // If the signers aren't in selectedTxsSignersSeqs then we haven't seen them before @@ -316,8 +320,7 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan txSignersSeqs[signer.Signer.String()] = signer.Sequence } if !shouldAdd { - iterator = iterator.Next() - continue + return true } // NOTE: Since transaction verification was already executed in CheckTx, @@ -326,14 +329,11 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan // check again. txBz, err := h.txVerifier.PrepareProposalVerifyTx(memTx) if err != nil { - err := h.mempool.Remove(memTx) - if err != nil && !errors.Is(err, mempool.ErrTxNotFound) { - return nil, err - } + invalidTxs = append(invalidTxs, memTx) } else { stop := h.txSelector.SelectTxForProposal(ctx, uint64(req.MaxTxBytes), maxBlockGas, memTx, txBz) if stop { - break + return false } txsLen := len(h.txSelector.SelectedTxs(ctx)) @@ -354,7 +354,18 @@ func (h *DefaultProposalHandler) PrepareProposalHandler() sdk.PrepareProposalHan selectedTxsNums = txsLen } - iterator = iterator.Next() + return true + }) + + if resError != nil { + return nil, resError + } + + for _, tx := range invalidTxs { + err := h.mempool.Remove(tx) + if err != nil && !errors.Is(err, mempool.ErrTxNotFound) { + return nil, err + } } return &abci.PrepareProposalResponse{Txs: h.txSelector.SelectedTxs(ctx)}, nil diff --git a/types/mempool/mempool.go b/types/mempool/mempool.go index 7051c93e3146..4f8f82f16fa7 100644 --- a/types/mempool/mempool.go +++ b/types/mempool/mempool.go @@ -13,10 +13,12 @@ type Mempool interface { Insert(context.Context, sdk.Tx) error // Select returns an Iterator over the app-side mempool. If txs are specified, - // then they shall be incorporated into the Iterator. The Iterator must be - // closed by the caller. + // then they shall be incorporated into the Iterator. The Iterator is not thread-safe to use. Select(context.Context, [][]byte) Iterator + // SelectBy use callback to iterate over the mempool, it's thread-safe to use. + SelectBy(context.Context, [][]byte, func(sdk.Tx) bool) + // CountTx returns the number of transactions currently in the mempool. CountTx() int diff --git a/types/mempool/noop.go b/types/mempool/noop.go index 73c12639d1d6..33c002080f82 100644 --- a/types/mempool/noop.go +++ b/types/mempool/noop.go @@ -16,7 +16,8 @@ var _ Mempool = (*NoOpMempool)(nil) // is FIFO-ordered by default. type NoOpMempool struct{} -func (NoOpMempool) Insert(context.Context, sdk.Tx) error { return nil } -func (NoOpMempool) Select(context.Context, [][]byte) Iterator { return nil } -func (NoOpMempool) CountTx() int { return 0 } -func (NoOpMempool) Remove(sdk.Tx) error { return nil } +func (NoOpMempool) Insert(context.Context, sdk.Tx) error { return nil } +func (NoOpMempool) Select(context.Context, [][]byte) Iterator { return nil } +func (NoOpMempool) SelectBy(context.Context, [][]byte, func(sdk.Tx) bool) {} +func (NoOpMempool) CountTx() int { return 0 } +func (NoOpMempool) Remove(sdk.Tx) error { return nil } diff --git a/types/mempool/priority_nonce.go b/types/mempool/priority_nonce.go index a927693410ef..f081e2b413db 100644 --- a/types/mempool/priority_nonce.go +++ b/types/mempool/priority_nonce.go @@ -351,9 +351,13 @@ func (i *PriorityNonceIterator[C]) Tx() sdk.Tx { // // NOTE: It is not safe to use this iterator while removing transactions from // the underlying mempool. -func (mp *PriorityNonceMempool[C]) Select(_ context.Context, _ [][]byte) Iterator { +func (mp *PriorityNonceMempool[C]) Select(ctx context.Context, txs [][]byte) Iterator { mp.mtx.Lock() defer mp.mtx.Unlock() + return mp.doSelect(ctx, txs) +} + +func (mp *PriorityNonceMempool[C]) doSelect(_ context.Context, _ [][]byte) Iterator { if mp.priorityIndex.Len() == 0 { return nil } @@ -368,6 +372,17 @@ func (mp *PriorityNonceMempool[C]) Select(_ context.Context, _ [][]byte) Iterato return iterator.iteratePriority() } +// SelectBy will hold the mutex during the iteration, callback returns if continue. +func (mp *PriorityNonceMempool[C]) SelectBy(ctx context.Context, txs [][]byte, callback func(sdk.Tx) bool) { + mp.mtx.Lock() + defer mp.mtx.Unlock() + + iter := mp.doSelect(ctx, txs) + for iter != nil && callback(iter.Tx()) { + iter = iter.Next() + } +} + type reorderKey[C comparable] struct { deleteKey txMeta[C] insertKey txMeta[C] diff --git a/types/mempool/priority_nonce_test.go b/types/mempool/priority_nonce_test.go index 4b1c27c1808b..3bfd7e4ba86c 100644 --- a/types/mempool/priority_nonce_test.go +++ b/types/mempool/priority_nonce_test.go @@ -1,9 +1,11 @@ package mempool_test import ( + "context" "fmt" "math" "math/rand" + "sync" "testing" "time" @@ -395,6 +397,89 @@ func (s *MempoolTestSuite) TestIterator() { } } +func (s *MempoolTestSuite) TestIteratorConcurrency() { + t := s.T() + ctx := sdk.NewContext(nil, false, log.NewNopLogger()) + accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 2) + sa := accounts[0].Address + sb := accounts[1].Address + + tests := []struct { + txs []txSpec + fail bool + }{ + { + txs: []txSpec{ + {p: 20, n: 1, a: sa}, + {p: 15, n: 1, a: sb}, + {p: 6, n: 2, a: sa}, + {p: 21, n: 4, a: sa}, + {p: 8, n: 2, a: sb}, + }, + }, + { + txs: []txSpec{ + {p: 20, n: 1, a: sa}, + {p: 15, n: 1, a: sb}, + {p: 6, n: 2, a: sa}, + {p: 21, n: 4, a: sa}, + {p: math.MinInt64, n: 2, a: sb}, + }, + }, + } + + for i, tt := range tests { + t.Run(fmt.Sprintf("case %d", i), func(t *testing.T) { + pool := mempool.DefaultPriorityMempool() + + // create test txs and insert into mempool + for i, ts := range tt.txs { + tx := testTx{id: i, priority: int64(ts.p), nonce: uint64(ts.n), address: ts.a} + c := ctx.WithPriority(tx.priority) + err := pool.Insert(c, tx) + require.NoError(t, err) + } + + // iterate through txs + stdCtx, cancel := context.WithCancel(context.Background()) + var wg sync.WaitGroup + wg.Add(1) + go func() { + defer wg.Done() + + id := len(tt.txs) + for { + select { + case <-stdCtx.Done(): + return + default: + id++ + tx := testTx{id: id, priority: int64(rand.Intn(100)), nonce: uint64(id), address: sa} + c := ctx.WithPriority(tx.priority) + err := pool.Insert(c, tx) + require.NoError(t, err) + } + } + }() + + var i int + pool.SelectBy(ctx, nil, func(memTx sdk.Tx) bool { + tx := memTx.(testTx) + if tx.id < len(tt.txs) { + require.Equal(t, tt.txs[tx.id].p, int(tx.priority)) + require.Equal(t, tt.txs[tx.id].n, int(tx.nonce)) + require.Equal(t, tt.txs[tx.id].a, tx.address) + i++ + } + return i < len(tt.txs) + }) + require.Equal(t, i, len(tt.txs)) + cancel() + wg.Wait() + }) + } +} + func (s *MempoolTestSuite) TestPriorityTies() { ctx := sdk.NewContext(nil, false, log.NewNopLogger()) accounts := simtypes.RandomAccounts(rand.New(rand.NewSource(0)), 3) diff --git a/types/mempool/sender_nonce.go b/types/mempool/sender_nonce.go index d69d5b6f6c18..ea9807c31ea0 100644 --- a/types/mempool/sender_nonce.go +++ b/types/mempool/sender_nonce.go @@ -158,9 +158,13 @@ func (snm *SenderNonceMempool) Insert(_ context.Context, tx sdk.Tx) error { // // NOTE: It is not safe to use this iterator while removing transactions from // the underlying mempool. -func (snm *SenderNonceMempool) Select(_ context.Context, _ [][]byte) Iterator { +func (snm *SenderNonceMempool) Select(ctx context.Context, txs [][]byte) Iterator { snm.mtx.Lock() defer snm.mtx.Unlock() + return snm.doSelect(ctx, txs) +} + +func (snm *SenderNonceMempool) doSelect(_ context.Context, _ [][]byte) Iterator { var senders []string senderCursors := make(map[string]*skiplist.Element) @@ -188,6 +192,17 @@ func (snm *SenderNonceMempool) Select(_ context.Context, _ [][]byte) Iterator { return iter.Next() } +// SelectBy will hold the mutex during the iteration, callback returns if continue. +func (snm *SenderNonceMempool) SelectBy(ctx context.Context, txs [][]byte, callback func(sdk.Tx) bool) { + snm.mtx.Lock() + defer snm.mtx.Unlock() + + iter := snm.doSelect(ctx, txs) + for iter != nil && callback(iter.Tx()) { + iter = iter.Next() + } +} + // CountTx returns the total count of txs in the mempool. func (snm *SenderNonceMempool) CountTx() int { snm.mtx.Lock() From 6ffa71abd3096a64e6330a9a2bb56c08b05b9972 Mon Sep 17 00:00:00 2001 From: Julien Robert Date: Wed, 4 Sep 2024 12:57:07 +0200 Subject: [PATCH 2/4] chore: extract improvements from #21497 (#21506) Co-authored-by: Akhil Kumar P <36399231+akhilkumarpilli@users.noreply.github.com> --- UPGRADING.md | 52 ++++++------------------------- docs/learn/advanced/00-baseapp.md | 2 +- server/v2/cometbft/abci.go | 2 +- server/v2/types.go | 2 +- simapp/v2/app_config.go | 6 +++- x/distribution/keeper/abci.go | 1 - 6 files changed, 17 insertions(+), 48 deletions(-) diff --git a/UPGRADING.md b/UPGRADING.md index 449583638f4b..70eaeaa472c1 100644 --- a/UPGRADING.md +++ b/UPGRADING.md @@ -106,7 +106,7 @@ For non depinject users, simply call `RegisterLegacyAminoCodec` and `RegisterInt Additionally, thanks to the genesis simplification, as explained in [the genesis interface update](#genesis-interface), the module manager `InitGenesis` and `ExportGenesis` methods do not require the codec anymore. -##### GRPC-WEB +##### GRPC WEB Grpc-web embedded client has been removed from the server. If you would like to use grpc-web, you can use the [envoy proxy](https://www.envoyproxy.io/docs/envoy/latest/start/start). Here's how to set it up: @@ -347,6 +347,8 @@ Also, any usages of the interfaces `AnyUnpacker` and `UnpackInterfacesMessage` m #### `**all**` +All modules (expect `auth`) were spun out into their own `go.mod`. Replace their imports by `cosmossdk.io/x/{moduleName}`. + ##### Core API Core API has been introduced for modules since v0.47. With the deprecation of `sdk.Context`, we strongly recommend to use the `cosmossdk.io/core/appmodule` interfaces for the modules. This will allow the modules to work out of the box with server/v2 and baseapp, as well as limit their dependencies on the SDK. @@ -399,7 +401,7 @@ All modules using dependency injection must update their imports. ##### Params -Previous module migrations have been removed. It is required to migrate to v0.50 prior to upgrading to v0.51 for not missing any module migrations. +Previous module migrations have been removed. It is required to migrate to v0.50 prior to upgrading to v0.52 for not missing any module migrations. ##### Genesis Interface @@ -436,60 +438,24 @@ if err != nil { } ``` -#### `x/auth` - -Auth was spun out into its own `go.mod`. To import it use `cosmossdk.io/x/auth` - -#### `x/authz` - -Authz was spun out into its own `go.mod`. To import it use `cosmossdk.io/x/authz` +### `x/crisis` -#### `x/bank` - -Bank was spun out into its own `go.mod`. To import it use `cosmossdk.io/x/bank` - -### `x/crsis` - -The Crisis module was removed due to it not being supported or functional any longer. +The `x/crisis` module was removed due to it not being supported or functional any longer. #### `x/distribution` -Distribution was spun out into its own `go.mod`. To import it use `cosmossdk.io/x/distribution` - -The existing chains using x/distribution module needs to add the new x/protocolpool module. - -#### `x/group` - -Group was spun out into its own `go.mod`. To import it use `cosmossdk.io/x/group` +Existing chains using `x/distribution` module must add the new `x/protocolpool` module. #### `x/gov` -Gov was spun out into its own `go.mod`. To import it use `cosmossdk.io/x/gov` - Gov v1beta1 proposal handler has been changed to take in a `context.Context` instead of `sdk.Context`. This change was made to allow legacy proposals to be compatible with server/v2. If you wish to migrate to server/v2, you should update your proposal handler to take in a `context.Context` and use services. On the other hand, if you wish to keep using baseapp, simply unwrap the sdk context in your proposal handler. -#### `x/mint` - -Mint was spun out into its own `go.mod`. To import it use `cosmossdk.io/x/mint` - -#### `x/slashing` - -Slashing was spun out into its own `go.mod`. To import it use `cosmossdk.io/x/slashing` - -#### `x/staking` - -Staking was spun out into its own `go.mod`. To import it use `cosmossdk.io/x/staking` - -#### `x/params` - -A standalone Go module was created and it is accessible at "cosmossdk.io/x/params". - #### `x/protocolpool` -Introducing a new `x/protocolpool` module to handle community pool funds. Its store must be added while upgrading to v0.51.x. +Introducing a new `x/protocolpool` module to handle community pool funds. Its store must be added while upgrading to v0.52.x. Example: @@ -506,7 +472,7 @@ func (app SimApp) RegisterUpgradeHandlers() { } ``` -Add `x/protocolpool` store while upgrading to v0.51.x: +Add `x/protocolpool` store while upgrading to v0.52.x: ```go storetypes.StoreUpgrades{ diff --git a/docs/learn/advanced/00-baseapp.md b/docs/learn/advanced/00-baseapp.md index 1a7cd28fdc96..20968f91bd0f 100644 --- a/docs/learn/advanced/00-baseapp.md +++ b/docs/learn/advanced/00-baseapp.md @@ -205,7 +205,7 @@ newly committed state and `finalizeBlockState` is set to `nil` to be reset on `F During `InitChain`, the `RequestInitChain` provides `ConsensusParams` which contains parameters related to block execution such as maximum gas and size in addition to evidence parameters. If these parameters are non-nil, they are set in the BaseApp's `ParamStore`. Behind the scenes, the `ParamStore` -is managed by an `x/consensus_params` module. This allows the parameters to be tweaked via +is managed by an `x/consensus` module. This allows the parameters to be tweaked via on-chain governance. ## Service Routers diff --git a/server/v2/cometbft/abci.go b/server/v2/cometbft/abci.go index c30a38ada250..06c1e43e5ecb 100644 --- a/server/v2/cometbft/abci.go +++ b/server/v2/cometbft/abci.go @@ -67,7 +67,7 @@ type Consensus[T transaction.Tx] struct { func NewConsensus[T transaction.Tx]( logger log.Logger, appName string, - consensusAuthority string, + consensusAuthority string, // TODO remove app *appmanager.AppManager[T], mp mempool.Mempool[T], indexedEvents map[string]struct{}, diff --git a/server/v2/types.go b/server/v2/types.go index 3979f99f48db..f25dbb8ab388 100644 --- a/server/v2/types.go +++ b/server/v2/types.go @@ -16,7 +16,7 @@ type AppI[T transaction.Tx] interface { Name() string InterfaceRegistry() server.InterfaceRegistry GetAppManager() *appmanager.AppManager[T] - GetConsensusAuthority() string + GetConsensusAuthority() string // TODO remove GetGPRCMethodsToMessageMap() map[string]func() gogoproto.Message GetStore() any } diff --git a/simapp/v2/app_config.go b/simapp/v2/app_config.go index 0cd6cac6b879..bd60f1a0aa56 100644 --- a/simapp/v2/app_config.go +++ b/simapp/v2/app_config.go @@ -138,6 +138,10 @@ var ( ModuleName: authtypes.ModuleName, KvStoreKey: "acc", }, + { + ModuleName: accounts.ModuleName, + KvStoreKey: accounts.StoreKey, + }, }, // NOTE: The genutils module must occur after staking so that pools are // properly initialized with tokens from genesis accounts. @@ -260,7 +264,7 @@ var ( { Name: consensustypes.ModuleName, Config: appconfig.WrapAny(&consensusmodulev1.Module{ - Authority: "consensus", + Authority: "consensus", // TODO remove. }), }, { diff --git a/x/distribution/keeper/abci.go b/x/distribution/keeper/abci.go index 64cb205f5c11..5831ee9e1b9d 100644 --- a/x/distribution/keeper/abci.go +++ b/x/distribution/keeper/abci.go @@ -10,7 +10,6 @@ import ( // BeginBlocker sets the proposer for determining distribution during endblock // and distribute rewards for the previous block. -// TODO: use context.Context after including the comet service func (k Keeper) BeginBlocker(ctx context.Context) error { defer telemetry.ModuleMeasureSince(types.ModuleName, telemetry.Now(), telemetry.MetricKeyBeginBlocker) From 4b78f15f650798bb4b21cbc43fe80207b976f140 Mon Sep 17 00:00:00 2001 From: Reece Williams <31943163+Reecepbcups@users.noreply.github.com> Date: Wed, 4 Sep 2024 04:33:09 -0700 Subject: [PATCH 3/4] feat(x/genutil)!: bulk add genesis accounts (#21372) Co-authored-by: Julien Robert --- CHANGELOG.md | 9 +- x/genutil/client/cli/commands.go | 1 + x/genutil/client/cli/genaccount.go | 93 +++++++++++- x/genutil/client/cli/genaccount_test.go | 169 +++++++++++++++++++++ x/genutil/genaccounts.go | 187 +++++++++++++----------- 5 files changed, 368 insertions(+), 91 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 757f475816f8..b99f57289593 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -43,6 +43,7 @@ Every module contains its own CHANGELOG.md. Please refer to the module you are i ### Features * (baseapp) [#20291](https://github.com/cosmos/cosmos-sdk/pull/20291) Simulate nested messages. +* (cli) [#21372](https://github.com/cosmos/cosmos-sdk/pull/21372) Add a `bulk-add-genesis-account` genesis command to add many genesis accounts at once. ### Improvements @@ -151,7 +152,7 @@ Every module contains its own CHANGELOG.md. Please refer to the module you are i * (client) [#17215](https://github.com/cosmos/cosmos-sdk/pull/17215) `server.StartCmd`,`server.ExportCmd`,`server.NewRollbackCmd`,`pruning.Cmd`,`genutilcli.InitCmd`,`genutilcli.GenTxCmd`,`genutilcli.CollectGenTxsCmd`,`genutilcli.AddGenesisAccountCmd`, do not take a home directory anymore. It is inferred from the root command. * (client) [#17259](https://github.com/cosmos/cosmos-sdk/pull/17259) Remove deprecated `clientCtx.PrintObjectLegacy`. Use `clientCtx.PrintProto` or `clientCtx.PrintRaw` instead. * (types) [#17348](https://github.com/cosmos/cosmos-sdk/pull/17348) Remove the `WrapServiceResult` function. - * The `*sdk.Result` returned by the msg server router will not contain the `.Data` field. + * The `*sdk.Result` returned by the msg server router will not contain the `.Data` field. * (types) [#17426](https://github.com/cosmos/cosmos-sdk/pull/17426) `NewContext` does not take a `cmtproto.Header{}` any longer. * `WithChainID` / `WithBlockHeight` / `WithBlockHeader` must be used to set values on the context * (client/keys) [#17503](https://github.com/cosmos/cosmos-sdk/pull/17503) `clientkeys.NewKeyOutput`, `MkConsKeyOutput`, `MkValKeyOutput`, `MkAccKeyOutput`, `MkAccKeysOutput` now take their corresponding address codec instead of using the global SDK config. @@ -205,7 +206,7 @@ Every module contains its own CHANGELOG.md. Please refer to the module you are i * (x/crisis) [#20043](https://github.com/cosmos/cosmos-sdk/pull/20043) Changed `NewMsgVerifyInvariant` to accept a string as argument instead of an `AccAddress`. * (x/simulation)[#20056](https://github.com/cosmos/cosmos-sdk/pull/20056) `SimulateFromSeed` now takes an address codec as argument. * (server) [#20140](https://github.com/cosmos/cosmos-sdk/pull/20140) Remove embedded grpc-web proxy in favor of standalone grpc-web proxy. [Envoy Proxy](https://www.envoyproxy.io/docs/envoy/latest/start/start) -* (client) [#20255](https://github.com/cosmos/cosmos-sdk/pull/20255) Use comet proofOp proto type instead of sdk version to avoid needing to translate to later be proven in the merkle proof runtime. +* (client) [#20255](https://github.com/cosmos/cosmos-sdk/pull/20255) Use comet proofOp proto type instead of sdk version to avoid needing to translate to later be proven in the merkle proof runtime. * (types)[#20369](https://github.com/cosmos/cosmos-sdk/pull/20369) The signature of `HasAminoCodec` has changed to accept a `core/legacy.Amino` interface instead of `codec.LegacyAmino`. * (server) [#20422](https://github.com/cosmos/cosmos-sdk/pull/20422) Deprecated `ServerContext`. To get `cmtcfg.Config` from cmd, use `client.GetCometConfigFromCmd(cmd)` instead of `server.GetServerContextFromCmd(cmd).Config` * (x/genutil) [#20740](https://github.com/cosmos/cosmos-sdk/pull/20740) Update `genutilcli.Commands` and `genutilcli.CommandsWithCustomMigrationMap` to take the genesis module and abstract the module manager. @@ -217,7 +218,7 @@ Every module contains its own CHANGELOG.md. Please refer to the module you are i ### Client Breaking Changes -* (runtime) [#19040](https://github.com/cosmos/cosmos-sdk/pull/19040) Simplify app config implementation and deprecate `/cosmos/app/v1alpha1/config` query. +* (runtime) [#19040](https://github.com/cosmos/cosmos-sdk/pull/19040) Simplify app config implementation and deprecate `/cosmos/app/v1alpha1/config` query. ### CLI Breaking Changes @@ -281,7 +282,7 @@ Every module contains its own CHANGELOG.md. Please refer to the module you are i * (types) [#19759](https://github.com/cosmos/cosmos-sdk/pull/19759) Align SignerExtractionAdapter in PriorityNonceMempool Remove. * (client) [#19870](https://github.com/cosmos/cosmos-sdk/pull/19870) Add new query command `wait-tx`. Alias `event-query-tx-for` to `wait-tx` for backward compatibility. -### Improvements +### Improvements * (telemetry) [#19903](https://github.com/cosmos/cosmos-sdk/pull/19903) Conditionally emit metrics based on enablement. * **Introduction of `Now` Function**: Added a new function called `Now` to the telemetry package. It returns the current system time if telemetry is enabled, or a zero time if telemetry is not enabled. diff --git a/x/genutil/client/cli/commands.go b/x/genutil/client/cli/commands.go index 6e1415fe0796..00042ec69b8d 100644 --- a/x/genutil/client/cli/commands.go +++ b/x/genutil/client/cli/commands.go @@ -39,6 +39,7 @@ func CommandsWithCustomMigrationMap(genutilModule genutil.AppModule, genMM genes CollectGenTxsCmd(genutilModule.GenTxValidator()), ValidateGenesisCmd(genMM), AddGenesisAccountCmd(), + AddBulkGenesisAccountCmd(), ExportCmd(appExport), ) diff --git a/x/genutil/client/cli/genaccount.go b/x/genutil/client/cli/genaccount.go index 34acef113e2e..938e711b3aca 100644 --- a/x/genutil/client/cli/genaccount.go +++ b/x/genutil/client/cli/genaccount.go @@ -2,7 +2,9 @@ package cli import ( "bufio" + "encoding/json" "fmt" + "os" "github.com/spf13/cobra" @@ -71,7 +73,33 @@ contain valid denominations. Accounts may optionally be supplied with vesting pa vestingAmtStr, _ := cmd.Flags().GetString(flagVestingAmt) moduleNameStr, _ := cmd.Flags().GetString(flagModuleName) - return genutil.AddGenesisAccount(clientCtx.Codec, clientCtx.AddressCodec, addr, appendflag, config.GenesisFile(), args[1], vestingAmtStr, vestingStart, vestingEnd, moduleNameStr) + addrStr, err := addressCodec.BytesToString(addr) + if err != nil { + return err + } + + coins, err := sdk.ParseCoinsNormalized(args[1]) + if err != nil { + return err + } + + vestingAmt, err := sdk.ParseCoinsNormalized(vestingAmtStr) + if err != nil { + return err + } + + accounts := []genutil.GenesisAccount{ + { + Address: addrStr, + Coins: coins, + VestingAmt: vestingAmt, + VestingStart: vestingStart, + VestingEnd: vestingEnd, + ModuleName: moduleNameStr, + }, + } + + return genutil.AddGenesisAccounts(clientCtx.Codec, clientCtx.AddressCodec, accounts, appendflag, config.GenesisFile()) }, } @@ -85,3 +113,66 @@ contain valid denominations. Accounts may optionally be supplied with vesting pa return cmd } + +// AddBulkGenesisAccountCmd returns bulk-add-genesis-account cobra Command. +// This command is provided as a default, applications are expected to provide their own command if custom genesis accounts are needed. +func AddBulkGenesisAccountCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "bulk-add-genesis-account [/file/path.json]", + Short: "Bulk add genesis accounts to genesis.json", + Example: `bulk-add-genesis-account accounts.json + +where accounts.json is: + +[ + { + "address": "cosmos139f7kncmglres2nf3h4hc4tade85ekfr8sulz5", + "coins": [ + { "denom": "umuon", "amount": "100000000" }, + { "denom": "stake", "amount": "200000000" } + ] + }, + { + "address": "cosmos1e0jnq2sun3dzjh8p2xq95kk0expwmd7shwjpfg", + "coins": [ + { "denom": "umuon", "amount": "500000000" } + ], + "vesting_amt": [ + { "denom": "umuon", "amount": "400000000" } + ], + "vesting_start": 1724711478, + "vesting_end": 1914013878 + } +] +`, + Long: `Add genesis accounts in bulk to genesis.json. The provided account must specify +the account address and a list of initial coins. The list of initial tokens must +contain valid denominations. Accounts may optionally be supplied with vesting parameters. +`, + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + clientCtx := client.GetClientContextFromCmd(cmd) + config := client.GetConfigFromCmd(cmd) + + f, err := os.Open(args[0]) + if err != nil { + return fmt.Errorf("failed to open file: %w", err) + } + defer f.Close() + + var accounts []genutil.GenesisAccount + if err := json.NewDecoder(f).Decode(&accounts); err != nil { + return fmt.Errorf("failed to decode JSON: %w", err) + } + + appendflag, _ := cmd.Flags().GetBool(flagAppendMode) + + return genutil.AddGenesisAccounts(clientCtx.Codec, clientCtx.AddressCodec, accounts, appendflag, config.GenesisFile()) + }, + } + + cmd.Flags().Bool(flagAppendMode, false, "append the coins to an account already in the genesis.json file") + flags.AddQueryFlagsToCmd(cmd) + + return cmd +} diff --git a/x/genutil/client/cli/genaccount_test.go b/x/genutil/client/cli/genaccount_test.go index c0b293cb43b3..c75894f3a2fc 100644 --- a/x/genutil/client/cli/genaccount_test.go +++ b/x/genutil/client/cli/genaccount_test.go @@ -2,6 +2,9 @@ package cli_test import ( "context" + "encoding/json" + "os" + "path" "testing" "github.com/spf13/viper" @@ -9,6 +12,7 @@ import ( corectx "cosmossdk.io/core/context" "cosmossdk.io/log" + banktypes "cosmossdk.io/x/bank/types" "github.com/cosmos/cosmos-sdk/client" codectestutil "github.com/cosmos/cosmos-sdk/codec/testutil" @@ -18,8 +22,10 @@ import ( sdk "github.com/cosmos/cosmos-sdk/types" moduletestutil "github.com/cosmos/cosmos-sdk/types/module/testutil" "github.com/cosmos/cosmos-sdk/x/auth" + "github.com/cosmos/cosmos-sdk/x/genutil" genutilcli "github.com/cosmos/cosmos-sdk/x/genutil/client/cli" genutiltest "github.com/cosmos/cosmos-sdk/x/genutil/client/testutil" + genutiltypes "github.com/cosmos/cosmos-sdk/x/genutil/types" ) func TestAddGenesisAccountCmd(t *testing.T) { @@ -111,3 +117,166 @@ func TestAddGenesisAccountCmd(t *testing.T) { }) } } + +func TestBulkAddGenesisAccountCmd(t *testing.T) { + ac := codectestutil.CodecOptions{}.GetAddressCodec() + _, _, addr1 := testdata.KeyTestPubAddr() + _, _, addr2 := testdata.KeyTestPubAddr() + _, _, addr3 := testdata.KeyTestPubAddr() + addr1Str, err := ac.BytesToString(addr1) + require.NoError(t, err) + addr2Str, err := ac.BytesToString(addr2) + require.NoError(t, err) + addr3Str, err := ac.BytesToString(addr3) + require.NoError(t, err) + + tests := []struct { + name string + state [][]genutil.GenesisAccount + expected map[string]sdk.Coins + appendFlag bool + expectErr bool + }{ + { + name: "invalid address", + state: [][]genutil.GenesisAccount{ + { + { + Address: "invalid", + Coins: sdk.NewCoins(sdk.NewInt64Coin("test", 1)), + }, + }, + }, + expectErr: true, + }, + { + name: "no append flag for multiple account adds", + state: [][]genutil.GenesisAccount{ + { + { + Address: addr1Str, + Coins: sdk.NewCoins(sdk.NewInt64Coin("test", 1)), + }, + }, + { + { + Address: addr1Str, + Coins: sdk.NewCoins(sdk.NewInt64Coin("test", 2)), + }, + }, + }, + appendFlag: false, + expectErr: true, + }, + + { + name: "multiple additions with append", + state: [][]genutil.GenesisAccount{ + { + { + Address: addr1Str, + Coins: sdk.NewCoins(sdk.NewInt64Coin("test", 1)), + }, + { + Address: addr2Str, + Coins: sdk.NewCoins(sdk.NewInt64Coin("test", 1)), + }, + }, + { + { + Address: addr1Str, + Coins: sdk.NewCoins(sdk.NewInt64Coin("test", 2)), + }, + { + Address: addr2Str, + Coins: sdk.NewCoins(sdk.NewInt64Coin("stake", 1)), + }, + { + Address: addr3Str, + Coins: sdk.NewCoins(sdk.NewInt64Coin("test", 1)), + }, + }, + }, + expected: map[string]sdk.Coins{ + addr1Str: sdk.NewCoins(sdk.NewInt64Coin("test", 3)), + addr2Str: sdk.NewCoins(sdk.NewInt64Coin("test", 1), sdk.NewInt64Coin("stake", 1)), + addr3Str: sdk.NewCoins(sdk.NewInt64Coin("test", 1)), + }, + appendFlag: true, + expectErr: false, + }, + } + + for _, tc := range tests { + tc := tc + t.Run(tc.name, func(t *testing.T) { + home := t.TempDir() + logger := log.NewNopLogger() + v := viper.New() + + encodingConfig := moduletestutil.MakeTestEncodingConfig(codectestutil.CodecOptions{}, auth.AppModule{}) + appCodec := encodingConfig.Codec + txConfig := encodingConfig.TxConfig + err = genutiltest.ExecInitCmd(testMbm, home, appCodec) + require.NoError(t, err) + + err = writeAndTrackDefaultConfig(v, home) + require.NoError(t, err) + clientCtx := client.Context{}.WithCodec(appCodec).WithHomeDir(home). + WithAddressCodec(ac).WithTxConfig(txConfig) + + ctx := context.Background() + ctx = context.WithValue(ctx, client.ClientContextKey, &clientCtx) + ctx = context.WithValue(ctx, corectx.ViperContextKey, v) + ctx = context.WithValue(ctx, corectx.LoggerContextKey, logger) + + // The first iteration (pre-append) may not error. + // Check if any errors after all state transitions to genesis. + doesErr := false + + // apply multiple state iterations if applicable (e.g. --append) + for _, state := range tc.state { + bz, err := json.Marshal(state) + require.NoError(t, err) + + filePath := path.Join(home, "accounts.json") + err = os.WriteFile(filePath, bz, 0o600) + require.NoError(t, err) + + cmd := genutilcli.AddBulkGenesisAccountCmd() + args := []string{filePath} + if tc.appendFlag { + args = append(args, "--append") + } + cmd.SetArgs(args) + + err = cmd.ExecuteContext(ctx) + if err != nil { + doesErr = true + } + } + require.Equal(t, tc.expectErr, doesErr) + + // an error already occurred, no need to check the state + if doesErr { + return + } + + appState, _, err := genutiltypes.GenesisStateFromGenFile(path.Join(home, "config", "genesis.json")) + require.NoError(t, err) + + bankState := banktypes.GetGenesisStateFromAppState(encodingConfig.Codec, appState) + + require.EqualValues(t, len(tc.expected), len(bankState.Balances)) + for _, acc := range bankState.Balances { + require.True(t, tc.expected[acc.Address].Equal(acc.Coins), "expected: %v, got: %v", tc.expected[acc.Address], acc.Coins) + } + + expectedSupply := sdk.NewCoins() + for _, coins := range tc.expected { + expectedSupply = expectedSupply.Add(coins...) + } + require.Equal(t, expectedSupply, bankState.Supply) + }) + } +} diff --git a/x/genutil/genaccounts.go b/x/genutil/genaccounts.go index 75899c8cfd89..d3472fb792f6 100644 --- a/x/genutil/genaccounts.go +++ b/x/genutil/genaccounts.go @@ -15,133 +15,148 @@ import ( genutiltypes "github.com/cosmos/cosmos-sdk/x/genutil/types" ) -// AddGenesisAccount adds a genesis account to the genesis state. -// Where `cdc` is client codec, `genesisFileUrl` is the path/url of current genesis file, -// `accAddr` is the address to be added to the genesis state, `amountStr` is the list of initial coins -// to be added for the account, `appendAcct` updates the account if already exists. -// `vestingStart, vestingEnd and vestingAmtStr` respectively are the schedule start time, end time (unix epoch) -// `moduleName` is the module name for which the account is being created -// and coins to be appended to the account already in the genesis.json file. -func AddGenesisAccount( +type GenesisAccount struct { + // Base + Address string `json:"address"` + Coins sdk.Coins `json:"coins"` + + // Vesting + VestingAmt sdk.Coins `json:"vesting_amt,omitempty"` + VestingStart int64 `json:"vesting_start,omitempty"` + VestingEnd int64 `json:"vesting_end,omitempty"` + + // Module + ModuleName string `json:"module_name,omitempty"` +} + +// AddGenesisAccounts adds genesis accounts to the genesis state. +// Where `cdc` is the client codec, `addressCodec` is the address codec, `accounts` are the genesis accounts to add, +// `appendAcct` updates the account if already exists, and `genesisFileURL` is the path/url of the current genesis file. +func AddGenesisAccounts( cdc codec.Codec, addressCodec address.Codec, - accAddr sdk.AccAddress, + accounts []GenesisAccount, appendAcct bool, - genesisFileURL, amountStr, vestingAmtStr string, - vestingStart, vestingEnd int64, - moduleName string, + genesisFileURL string, ) error { - addr, err := addressCodec.BytesToString(accAddr) + appState, appGenesis, err := genutiltypes.GenesisStateFromGenFile(genesisFileURL) if err != nil { - return err + return fmt.Errorf("failed to unmarshal genesis state: %w", err) } - coins, err := sdk.ParseCoinsNormalized(amountStr) - if err != nil { - return fmt.Errorf("failed to parse coins: %w", err) - } + authGenState := authtypes.GetGenesisStateFromAppState(cdc, appState) + bankGenState := banktypes.GetGenesisStateFromAppState(cdc, appState) - vestingAmt, err := sdk.ParseCoinsNormalized(vestingAmtStr) + accs, err := authtypes.UnpackAccounts(authGenState.Accounts) if err != nil { - return fmt.Errorf("failed to parse vesting amount: %w", err) + return fmt.Errorf("failed to get accounts from any: %w", err) } - // create concrete account type based on input parameters - var genAccount authtypes.GenesisAccount + newSupplyCoinsCache := sdk.NewCoins() + balanceCache := make(map[string]banktypes.Balance) + for _, acc := range accs { + for _, balance := range bankGenState.GetBalances() { + if balance.Address == acc.GetAddress().String() { + balanceCache[acc.GetAddress().String()] = balance + } + } + } - balances := banktypes.Balance{Address: addr, Coins: coins.Sort()} - baseAccount := authtypes.NewBaseAccount(accAddr, nil, 0, 0) + for _, acc := range accounts { + addr := acc.Address + coins := acc.Coins - if !vestingAmt.IsZero() { - baseVestingAccount, err := authvesting.NewBaseVestingAccount(baseAccount, vestingAmt.Sort(), vestingEnd) + accAddr, err := addressCodec.StringToBytes(addr) if err != nil { - return fmt.Errorf("failed to create base vesting account: %w", err) + return fmt.Errorf("failed to parse account address %s: %w", addr, err) } - if (balances.Coins.IsZero() && !baseVestingAccount.OriginalVesting.IsZero()) || - baseVestingAccount.OriginalVesting.IsAnyGT(balances.Coins) { - return errors.New("vesting amount cannot be greater than total amount") - } + // create concrete account type based on input parameters + var genAccount authtypes.GenesisAccount - switch { - case vestingStart != 0 && vestingEnd != 0: - genAccount = authvesting.NewContinuousVestingAccountRaw(baseVestingAccount, vestingStart) + balances := banktypes.Balance{Address: addr, Coins: coins.Sort()} + baseAccount := authtypes.NewBaseAccount(accAddr, nil, 0, 0) - case vestingEnd != 0: - genAccount = authvesting.NewDelayedVestingAccountRaw(baseVestingAccount) + vestingAmt := acc.VestingAmt + if !vestingAmt.IsZero() { + vestingStart := acc.VestingStart + vestingEnd := acc.VestingEnd - default: - return errors.New("invalid vesting parameters; must supply start and end time or end time") - } - } else if moduleName != "" { - genAccount = authtypes.NewEmptyModuleAccount(moduleName, authtypes.Burner, authtypes.Minter) - } else { - genAccount = baseAccount - } + baseVestingAccount, err := authvesting.NewBaseVestingAccount(baseAccount, vestingAmt.Sort(), vestingEnd) + if err != nil { + return fmt.Errorf("failed to create base vesting account: %w", err) + } - if err := genAccount.Validate(); err != nil { - return fmt.Errorf("failed to validate new genesis account: %w", err) - } + if (balances.Coins.IsZero() && !baseVestingAccount.OriginalVesting.IsZero()) || + baseVestingAccount.OriginalVesting.IsAnyGT(balances.Coins) { + return errors.New("vesting amount cannot be greater than total amount") + } - appState, appGenesis, err := genutiltypes.GenesisStateFromGenFile(genesisFileURL) - if err != nil { - return fmt.Errorf("failed to unmarshal genesis state: %w", err) - } + switch { + case vestingStart != 0 && vestingEnd != 0: + genAccount = authvesting.NewContinuousVestingAccountRaw(baseVestingAccount, vestingStart) - authGenState := authtypes.GetGenesisStateFromAppState(cdc, appState) + case vestingEnd != 0: + genAccount = authvesting.NewDelayedVestingAccountRaw(baseVestingAccount) - accs, err := authtypes.UnpackAccounts(authGenState.Accounts) - if err != nil { - return fmt.Errorf("failed to get accounts from any: %w", err) - } + default: + return errors.New("invalid vesting parameters; must supply start and end time or end time") + } + } else if acc.ModuleName != "" { + genAccount = authtypes.NewEmptyModuleAccount(acc.ModuleName, authtypes.Burner, authtypes.Minter) + } else { + genAccount = baseAccount + } - bankGenState := banktypes.GetGenesisStateFromAppState(cdc, appState) - if accs.Contains(accAddr) { - if !appendAcct { - return fmt.Errorf(" Account %s already exists\nUse `append` flag to append account at existing address", accAddr) + if err := genAccount.Validate(); err != nil { + return fmt.Errorf("failed to validate new genesis account: %w", err) } - genesisB := banktypes.GetGenesisStateFromAppState(cdc, appState) - for idx, acc := range genesisB.Balances { - if acc.Address != addr { - continue + if _, ok := balanceCache[addr]; ok { + if !appendAcct { + return fmt.Errorf(" Account %s already exists\nUse `append` flag to append account at existing address", accAddr) } - updatedCoins := acc.Coins.Add(coins...) - bankGenState.Balances[idx] = banktypes.Balance{Address: addr, Coins: updatedCoins.Sort()} - break - } - } else { - // Add the new account to the set of genesis accounts and sanitize the accounts afterwards. - accs = append(accs, genAccount) - accs = authtypes.SanitizeGenesisAccounts(accs) + for idx, acc := range bankGenState.Balances { + if acc.Address != addr { + continue + } - genAccs, err := authtypes.PackAccounts(accs) - if err != nil { - return fmt.Errorf("failed to convert accounts into any's: %w", err) + updatedCoins := acc.Coins.Add(coins...) + bankGenState.Balances[idx] = banktypes.Balance{Address: addr, Coins: updatedCoins.Sort()} + break + } + } else { + accs = append(accs, genAccount) + bankGenState.Balances = append(bankGenState.Balances, balances) } - authGenState.Accounts = genAccs - authGenStateBz, err := cdc.MarshalJSON(&authGenState) - if err != nil { - return fmt.Errorf("failed to marshal auth genesis state: %w", err) - } - appState[authtypes.ModuleName] = authGenStateBz + newSupplyCoinsCache = newSupplyCoinsCache.Add(coins...) + } + + accs = authtypes.SanitizeGenesisAccounts(accs) - bankGenState.Balances = append(bankGenState.Balances, balances) + authGenState.Accounts, err = authtypes.PackAccounts(accs) + if err != nil { + return fmt.Errorf("failed to convert accounts into any's: %w", err) + } + + appState[authtypes.ModuleName], err = cdc.MarshalJSON(&authGenState) + if err != nil { + return fmt.Errorf("failed to marshal auth genesis state: %w", err) } bankGenState.Balances, err = banktypes.SanitizeGenesisBalances(bankGenState.Balances, addressCodec) if err != nil { - return fmt.Errorf("failed to sanitize genesis balance: %w", err) + return fmt.Errorf("failed to sanitize genesis bank Balances: %w", err) } - bankGenState.Supply = bankGenState.Supply.Add(balances.Coins...) - bankGenStateBz, err := cdc.MarshalJSON(bankGenState) + bankGenState.Supply = bankGenState.Supply.Add(newSupplyCoinsCache...) + + appState[banktypes.ModuleName], err = cdc.MarshalJSON(bankGenState) if err != nil { return fmt.Errorf("failed to marshal bank genesis state: %w", err) } - appState[banktypes.ModuleName] = bankGenStateBz appStateJSON, err := json.Marshal(appState) if err != nil { From 292d7b49c3ba8e7293486c1bb829360f1eddbb5e Mon Sep 17 00:00:00 2001 From: Aaron Craelius Date: Wed, 4 Sep 2024 08:06:49 -0400 Subject: [PATCH 4/4] feat(indexer/postgres): add insert/update/delete functionality (#21186) --- indexer/postgres/delete.go | 61 +++++ indexer/postgres/indexer.go | 2 + indexer/postgres/insert_update.go | 116 +++++++++ indexer/postgres/listener.go | 28 +++ indexer/postgres/options.go | 8 +- indexer/postgres/params.go | 116 +++++++++ indexer/postgres/select.go | 299 ++++++++++++++++++++++++ indexer/postgres/tests/go.mod | 6 + indexer/postgres/tests/go.sum | 6 + indexer/postgres/tests/postgres_test.go | 99 ++++++++ indexer/postgres/view.go | 151 ++++++++++++ indexer/postgres/where.go | 60 +++++ 12 files changed, 951 insertions(+), 1 deletion(-) create mode 100644 indexer/postgres/delete.go create mode 100644 indexer/postgres/insert_update.go create mode 100644 indexer/postgres/params.go create mode 100644 indexer/postgres/select.go create mode 100644 indexer/postgres/tests/postgres_test.go create mode 100644 indexer/postgres/view.go create mode 100644 indexer/postgres/where.go diff --git a/indexer/postgres/delete.go b/indexer/postgres/delete.go new file mode 100644 index 000000000000..08bdb155dc62 --- /dev/null +++ b/indexer/postgres/delete.go @@ -0,0 +1,61 @@ +package postgres + +import ( + "context" + "fmt" + "io" + "strings" +) + +// delete deletes the row with the provided key from the table. +func (tm *objectIndexer) delete(ctx context.Context, conn dbConn, key interface{}) error { + buf := new(strings.Builder) + var params []interface{} + var err error + if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions { + params, err = tm.retainDeleteSqlAndParams(buf, key) + } else { + params, err = tm.deleteSqlAndParams(buf, key) + } + if err != nil { + return err + } + + sqlStr := buf.String() + tm.options.logger.Info("Delete", "sql", sqlStr, "params", params) + _, err = conn.ExecContext(ctx, sqlStr, params...) + return err +} + +// deleteSqlAndParams generates a DELETE statement and binding parameters for the provided key. +func (tm *objectIndexer) deleteSqlAndParams(w io.Writer, key interface{}) ([]interface{}, error) { + _, err := fmt.Fprintf(w, "DELETE FROM %q", tm.tableName()) + if err != nil { + return nil, err + } + + _, keyParams, err := tm.whereSqlAndParams(w, key, 1) + if err != nil { + return nil, err + } + + _, err = fmt.Fprintf(w, ";") + return keyParams, err +} + +// retainDeleteSqlAndParams generates an UPDATE statement to set the _deleted column to true for the provided key +// which is used when the table is set to retain deletions mode. +func (tm *objectIndexer) retainDeleteSqlAndParams(w io.Writer, key interface{}) ([]interface{}, error) { + _, err := fmt.Fprintf(w, "UPDATE %q SET _deleted = TRUE", tm.tableName()) + if err != nil { + return nil, err + } + + _, keyParams, err := tm.whereSqlAndParams(w, key, 1) + if err != nil { + return nil, err + } + + _, err = fmt.Fprintf(w, ";") + return keyParams, err +} diff --git a/indexer/postgres/indexer.go b/indexer/postgres/indexer.go index 2c37e9a79b11..bfaac25842e5 100644 --- a/indexer/postgres/indexer.go +++ b/indexer/postgres/indexer.go @@ -72,6 +72,7 @@ func StartIndexer(params indexer.InitParams) (indexer.InitResult, error) { opts := options{ disableRetainDeletions: config.DisableRetainDeletions, logger: params.Logger, + addressCodec: params.AddressCodec, } idx := &indexerImpl{ @@ -85,6 +86,7 @@ func StartIndexer(params indexer.InitParams) (indexer.InitResult, error) { return indexer.InitResult{ Listener: idx.listener(), + View: idx, }, nil } diff --git a/indexer/postgres/insert_update.go b/indexer/postgres/insert_update.go new file mode 100644 index 000000000000..fb246a84b61c --- /dev/null +++ b/indexer/postgres/insert_update.go @@ -0,0 +1,116 @@ +package postgres + +import ( + "context" + "fmt" + "io" + "strings" +) + +// insertUpdate inserts or updates the row with the provided key and value. +func (tm *objectIndexer) insertUpdate(ctx context.Context, conn dbConn, key, value interface{}) error { + exists, err := tm.exists(ctx, conn, key) + if err != nil { + return err + } + + buf := new(strings.Builder) + var params []interface{} + if exists { + if len(tm.typ.ValueFields) == 0 { + // special case where there are no value fields, so we can't update anything + return nil + } + + params, err = tm.updateSql(buf, key, value) + } else { + params, err = tm.insertSql(buf, key, value) + } + if err != nil { + return err + } + + sqlStr := buf.String() + if tm.options.logger != nil { + tm.options.logger.Debug("Insert or Update", "sql", sqlStr, "params", params) + } + _, err = conn.ExecContext(ctx, sqlStr, params...) + return err +} + +// insertSql generates an INSERT statement and binding parameters for the provided key and value. +func (tm *objectIndexer) insertSql(w io.Writer, key, value interface{}) ([]interface{}, error) { + keyParams, keyCols, err := tm.bindKeyParams(key) + if err != nil { + return nil, err + } + + valueParams, valueCols, err := tm.bindValueParams(value) + if err != nil { + return nil, err + } + + var allParams []interface{} + allParams = append(allParams, keyParams...) + allParams = append(allParams, valueParams...) + + allCols := make([]string, 0, len(keyCols)+len(valueCols)) + allCols = append(allCols, keyCols...) + allCols = append(allCols, valueCols...) + + var paramBindings []string + for i := 1; i <= len(allCols); i++ { + paramBindings = append(paramBindings, fmt.Sprintf("$%d", i)) + } + + _, err = fmt.Fprintf(w, "INSERT INTO %q (%s) VALUES (%s);", tm.tableName(), + strings.Join(allCols, ", "), + strings.Join(paramBindings, ", "), + ) + return allParams, err +} + +// updateSql generates an UPDATE statement and binding parameters for the provided key and value. +func (tm *objectIndexer) updateSql(w io.Writer, key, value interface{}) ([]interface{}, error) { + _, err := fmt.Fprintf(w, "UPDATE %q SET ", tm.tableName()) + if err != nil { + return nil, err + } + + valueParams, valueCols, err := tm.bindValueParams(value) + if err != nil { + return nil, err + } + + paramIdx := 1 + for i, col := range valueCols { + if i > 0 { + _, err = fmt.Fprintf(w, ", ") + if err != nil { + return nil, err + } + } + _, err = fmt.Fprintf(w, "%s = $%d", col, paramIdx) + if err != nil { + return nil, err + } + + paramIdx++ + } + + if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions { + _, err = fmt.Fprintf(w, ", _deleted = FALSE") + if err != nil { + return nil, err + } + } + + _, keyParams, err := tm.whereSqlAndParams(w, key, paramIdx) + if err != nil { + return nil, err + } + + allParams := append(valueParams, keyParams...) + _, err = fmt.Fprintf(w, ";") + return allParams, err +} diff --git a/indexer/postgres/listener.go b/indexer/postgres/listener.go index 1f46c1c7c55c..21d08a736af7 100644 --- a/indexer/postgres/listener.go +++ b/indexer/postgres/listener.go @@ -25,6 +25,34 @@ func (i *indexerImpl) listener() appdata.Listener { _, err := i.tx.Exec("INSERT INTO block (number) VALUES ($1)", data.Height) return err }, + OnObjectUpdate: func(data appdata.ObjectUpdateData) error { + module := data.ModuleName + mod, ok := i.modules[module] + if !ok { + return fmt.Errorf("module %s not initialized", module) + } + + for _, update := range data.Updates { + if i.logger != nil { + i.logger.Debug("OnObjectUpdate", "module", module, "type", update.TypeName, "key", update.Key, "delete", update.Delete, "value", update.Value) + } + tm, ok := mod.tables[update.TypeName] + if !ok { + return fmt.Errorf("object type %s not found in schema for module %s", update.TypeName, module) + } + + var err error + if update.Delete { + err = tm.delete(i.ctx, i.tx, update.Key) + } else { + err = tm.insertUpdate(i.ctx, i.tx, update.Key, update.Value) + } + if err != nil { + return err + } + } + return nil + }, Commit: func(data appdata.CommitData) (func() error, error) { err := i.tx.Commit() if err != nil { diff --git a/indexer/postgres/options.go b/indexer/postgres/options.go index d18a4c4d7f2c..db905f9dbaa7 100644 --- a/indexer/postgres/options.go +++ b/indexer/postgres/options.go @@ -1,6 +1,9 @@ package postgres -import "cosmossdk.io/schema/logutil" +import ( + "cosmossdk.io/schema/addressutil" + "cosmossdk.io/schema/logutil" +) // options are the options for module and object indexers. type options struct { @@ -9,4 +12,7 @@ type options struct { // logger is the logger for the indexer to use. It may be nil. logger logutil.Logger + + // addressCodec is the codec for encoding and decoding addresses. It is expected to be non-nil. + addressCodec addressutil.AddressCodec } diff --git a/indexer/postgres/params.go b/indexer/postgres/params.go new file mode 100644 index 000000000000..b2af8f6f174a --- /dev/null +++ b/indexer/postgres/params.go @@ -0,0 +1,116 @@ +package postgres + +import ( + "fmt" + "time" + + "cosmossdk.io/schema" +) + +// bindKeyParams binds the key to the key columns. +func (tm *objectIndexer) bindKeyParams(key interface{}) ([]interface{}, []string, error) { + n := len(tm.typ.KeyFields) + if n == 0 { + // singleton, set _id = 1 + return []interface{}{1}, []string{"_id"}, nil + } else if n == 1 { + return tm.bindParams(tm.typ.KeyFields, []interface{}{key}) + } else { + key, ok := key.([]interface{}) + if !ok { + return nil, nil, fmt.Errorf("expected key to be a slice") + } + + return tm.bindParams(tm.typ.KeyFields, key) + } +} + +func (tm *objectIndexer) bindValueParams(value interface{}) (params []interface{}, valueCols []string, err error) { + n := len(tm.typ.ValueFields) + if n == 0 { + return nil, nil, nil + } else if valueUpdates, ok := value.(schema.ValueUpdates); ok { + var e error + var fields []schema.Field + var params []interface{} + if err := valueUpdates.Iterate(func(name string, value interface{}) bool { + field, ok := tm.valueFields[name] + if !ok { + e = fmt.Errorf("unknown column %q", name) + return false + } + fields = append(fields, field) + params = append(params, value) + return true + }); err != nil { + return nil, nil, err + } + if e != nil { + return nil, nil, e + } + + return tm.bindParams(fields, params) + } else if n == 1 { + return tm.bindParams(tm.typ.ValueFields, []interface{}{value}) + } else { + values, ok := value.([]interface{}) + if !ok { + return nil, nil, fmt.Errorf("expected values to be a slice") + } + + return tm.bindParams(tm.typ.ValueFields, values) + } +} + +func (tm *objectIndexer) bindParams(fields []schema.Field, values []interface{}) ([]interface{}, []string, error) { + names := make([]string, 0, len(fields)) + params := make([]interface{}, 0, len(fields)) + for i, field := range fields { + if i >= len(values) { + return nil, nil, fmt.Errorf("missing value for field %q", field.Name) + } + + param, err := tm.bindParam(field, values[i]) + if err != nil { + return nil, nil, err + } + + name, err := tm.updatableColumnName(field) + if err != nil { + return nil, nil, err + } + + names = append(names, name) + params = append(params, param) + } + return params, names, nil +} + +func (tm *objectIndexer) bindParam(field schema.Field, value interface{}) (param interface{}, err error) { + param = value + if value == nil { + if !field.Nullable { + return nil, fmt.Errorf("expected non-null value for field %q", field.Name) + } + } else if field.Kind == schema.TimeKind { + t, ok := value.(time.Time) + if !ok { + return nil, fmt.Errorf("expected time.Time value for field %q, got %T", field.Name, value) + } + + param = t.UnixNano() + } else if field.Kind == schema.DurationKind { + t, ok := value.(time.Duration) + if !ok { + return nil, fmt.Errorf("expected time.Duration value for field %q, got %T", field.Name, value) + } + + param = int64(t) + } else if field.Kind == schema.AddressKind { + param, err = tm.options.addressCodec.BytesToString(value.([]byte)) + if err != nil { + return nil, fmt.Errorf("address encoding failed for field %q: %w", field.Name, err) + } + } + return +} diff --git a/indexer/postgres/select.go b/indexer/postgres/select.go new file mode 100644 index 000000000000..46ef12d3f15c --- /dev/null +++ b/indexer/postgres/select.go @@ -0,0 +1,299 @@ +package postgres + +import ( + "context" + "database/sql" + "encoding/json" + "errors" + "fmt" + "io" + "strconv" + "strings" + "time" + + "cosmossdk.io/schema" +) + +// Count returns the number of rows in the table. +func (tm *objectIndexer) count(ctx context.Context, conn dbConn) (int, error) { + sqlStr := fmt.Sprintf("SELECT COUNT(*) FROM %q;", tm.tableName()) + if tm.options.logger != nil { + tm.options.logger.Debug("Count", "sql", sqlStr) + } + row := conn.QueryRowContext(ctx, sqlStr) + var count int + err := row.Scan(&count) + return count, err +} + +// exists checks if a row with the provided key exists in the table. +func (tm *objectIndexer) exists(ctx context.Context, conn dbConn, key interface{}) (bool, error) { + buf := new(strings.Builder) + params, err := tm.existsSqlAndParams(buf, key) + if err != nil { + return false, err + } + + return tm.checkExists(ctx, conn, buf.String(), params) +} + +// checkExists checks if a row exists in the table. +func (tm *objectIndexer) checkExists(ctx context.Context, conn dbConn, sqlStr string, params []interface{}) (bool, error) { + if tm.options.logger != nil { + tm.options.logger.Debug("Check exists", "sql", sqlStr, "params", params) + } + var res interface{} + err := conn.QueryRowContext(ctx, sqlStr, params...).Scan(&res) + switch err { + case nil: + return true, nil + case sql.ErrNoRows: + return false, nil + default: + return false, err + } +} + +// existsSqlAndParams generates a SELECT statement to check if a row with the provided key exists in the table. +func (tm *objectIndexer) existsSqlAndParams(w io.Writer, key interface{}) ([]interface{}, error) { + _, err := fmt.Fprintf(w, "SELECT 1 FROM %q", tm.tableName()) + if err != nil { + return nil, err + } + + _, keyParams, err := tm.whereSqlAndParams(w, key, 1) + if err != nil { + return nil, err + } + + _, err = fmt.Fprintf(w, ";") + return keyParams, err +} + +func (tm *objectIndexer) get(ctx context.Context, conn dbConn, key interface{}) (schema.ObjectUpdate, bool, error) { + buf := new(strings.Builder) + params, err := tm.getSqlAndParams(buf, key) + if err != nil { + return schema.ObjectUpdate{}, false, err + } + + sqlStr := buf.String() + if tm.options.logger != nil { + tm.options.logger.Debug("Get", "sql", sqlStr, "params", params) + } + + row := conn.QueryRowContext(ctx, sqlStr, params...) + return tm.readRow(row) +} + +func (tm *objectIndexer) selectAllSql(w io.Writer) error { + err := tm.selectAllClause(w) + if err != nil { + return err + } + + _, err = fmt.Fprintf(w, ";") + return err +} + +func (tm *objectIndexer) getSqlAndParams(w io.Writer, key interface{}) ([]interface{}, error) { + err := tm.selectAllClause(w) + if err != nil { + return nil, err + } + + keyParams, keyCols, err := tm.bindKeyParams(key) + if err != nil { + return nil, err + } + + _, keyParams, err = tm.whereSql(w, keyParams, keyCols, 1) + if err != nil { + return nil, err + } + + _, err = fmt.Fprintf(w, ";") + return keyParams, err +} + +func (tm *objectIndexer) selectAllClause(w io.Writer) error { + allFields := make([]string, 0, len(tm.typ.KeyFields)+len(tm.typ.ValueFields)) + + for _, field := range tm.typ.KeyFields { + colName, err := tm.updatableColumnName(field) + if err != nil { + return err + } + allFields = append(allFields, colName) + } + + for _, field := range tm.typ.ValueFields { + colName, err := tm.updatableColumnName(field) + if err != nil { + return err + } + allFields = append(allFields, colName) + } + + if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions { + allFields = append(allFields, "_deleted") + } + + _, err := fmt.Fprintf(w, "SELECT %s FROM %q", strings.Join(allFields, ", "), tm.tableName()) + if err != nil { + return err + } + + return nil +} + +func (tm *objectIndexer) readRow(row interface{ Scan(...interface{}) error }) (schema.ObjectUpdate, bool, error) { + var res []interface{} + for _, f := range tm.typ.KeyFields { + res = append(res, tm.colBindValue(f)) + } + + for _, f := range tm.typ.ValueFields { + res = append(res, tm.colBindValue(f)) + } + + if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions { + res = append(res, new(bool)) + } + + err := row.Scan(res...) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return schema.ObjectUpdate{}, false, err + } + return schema.ObjectUpdate{}, false, err + } + + var keys []interface{} + for _, field := range tm.typ.KeyFields { + x, err := tm.readCol(field, res[0]) + if err != nil { + return schema.ObjectUpdate{}, false, err + } + keys = append(keys, x) + res = res[1:] + } + + var key interface{} = keys + if len(keys) == 1 { + key = keys[0] + } + + var values []interface{} + for _, field := range tm.typ.ValueFields { + x, err := tm.readCol(field, res[0]) + if err != nil { + return schema.ObjectUpdate{}, false, err + } + values = append(values, x) + res = res[1:] + } + + var value interface{} = values + if len(values) == 1 { + value = values[0] + } + + update := schema.ObjectUpdate{ + TypeName: tm.typ.Name, + Key: key, + Value: value, + } + + if !tm.options.disableRetainDeletions && tm.typ.RetainDeletions { + deleted := res[0].(*bool) + if *deleted { + update.Delete = true + } + } + + return update, true, nil +} + +func (tm *objectIndexer) colBindValue(field schema.Field) interface{} { + switch field.Kind { + case schema.BytesKind: + return new(interface{}) + default: + return new(sql.NullString) + } +} + +func (tm *objectIndexer) readCol(field schema.Field, value interface{}) (interface{}, error) { + switch field.Kind { + case schema.BytesKind: + // for bytes types we either get []byte or nil + value = *value.(*interface{}) + return value, nil + default: + } + + nullStr := *value.(*sql.NullString) + if field.Nullable { + if !nullStr.Valid { + return nil, nil + } + } + str := nullStr.String + + switch field.Kind { + case schema.StringKind, schema.EnumKind, schema.IntegerStringKind, schema.DecimalStringKind: + return str, nil + case schema.Uint8Kind: + value, err := strconv.ParseUint(str, 10, 8) + return uint8(value), err + case schema.Uint16Kind: + value, err := strconv.ParseUint(str, 10, 16) + return uint16(value), err + case schema.Uint32Kind: + value, err := strconv.ParseUint(str, 10, 32) + return uint32(value), err + case schema.Uint64Kind: + value, err := strconv.ParseUint(str, 10, 64) + return value, err + case schema.Int8Kind: + value, err := strconv.ParseInt(str, 10, 8) + return int8(value), err + case schema.Int16Kind: + value, err := strconv.ParseInt(str, 10, 16) + return int16(value), err + case schema.Int32Kind: + value, err := strconv.ParseInt(str, 10, 32) + return int32(value), err + case schema.Int64Kind: + value, err := strconv.ParseInt(str, 10, 64) + return value, err + case schema.Float32Kind: + value, err := strconv.ParseFloat(str, 32) + return float32(value), err + case schema.Float64Kind: + value, err := strconv.ParseFloat(str, 64) + return value, err + case schema.BoolKind: + value, err := strconv.ParseBool(str) + return value, err + case schema.JSONKind: + return json.RawMessage(str), nil + case schema.TimeKind: + value, err := strconv.ParseInt(str, 10, 64) + if err != nil { + return nil, err + } + return time.Unix(0, value), nil + case schema.DurationKind: + value, err := strconv.ParseInt(str, 10, 64) + if err != nil { + return nil, err + } + return time.Duration(value), nil + case schema.AddressKind: + return tm.options.addressCodec.StringToBytes(str) + default: + return value, nil + } +} diff --git a/indexer/postgres/tests/go.mod b/indexer/postgres/tests/go.mod index a72a1dc7fc07..cb642a3e7c6f 100644 --- a/indexer/postgres/tests/go.mod +++ b/indexer/postgres/tests/go.mod @@ -5,6 +5,7 @@ go 1.23 require ( cosmossdk.io/indexer/postgres v0.0.0-00010101000000-000000000000 cosmossdk.io/schema v0.1.1 + cosmossdk.io/schema/testing v0.0.0 github.com/fergusstrange/embedded-postgres v1.29.0 github.com/hashicorp/consul/sdk v0.16.1 github.com/jackc/pgx/v5 v5.6.0 @@ -13,6 +14,7 @@ require ( ) require ( + github.com/cockroachdb/apd/v3 v3.2.1 // indirect github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc // indirect github.com/google/go-cmp v0.6.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect @@ -22,14 +24,18 @@ require ( github.com/lib/pq v1.10.9 // indirect github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/rogpeppe/go-internal v1.12.0 // indirect + github.com/tidwall/btree v1.7.0 // indirect github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect golang.org/x/crypto v0.26.0 // indirect golang.org/x/sync v0.8.0 // indirect golang.org/x/sys v0.23.0 // indirect golang.org/x/text v0.17.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + pgregory.net/rapid v1.1.0 // indirect ) replace cosmossdk.io/indexer/postgres => ../. replace cosmossdk.io/schema => ../../../schema + +replace cosmossdk.io/schema/testing => ../../../schema/testing diff --git a/indexer/postgres/tests/go.sum b/indexer/postgres/tests/go.sum index 809d5040b848..f310c988deaa 100644 --- a/indexer/postgres/tests/go.sum +++ b/indexer/postgres/tests/go.sum @@ -1,3 +1,5 @@ +github.com/cockroachdb/apd/v3 v3.2.1 h1:U+8j7t0axsIgvQUqthuNm82HIrYXodOV2iWLWtEaIwg= +github.com/cockroachdb/apd/v3 v3.2.1/go.mod h1:klXJcjp+FffLTHlhIG69tezTDvdP065naDsHzKhYSqc= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -32,6 +34,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +github.com/tidwall/btree v1.7.0 h1:L1fkJH/AuEh5zBnnBbmTwQ5Lt+bRJ5A8EWecslvo9iI= +github.com/tidwall/btree v1.7.0/go.mod h1:twD9XRA5jj9VUQGELzDO4HPQTNJsoWWfYEL+EUQ2cKY= github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo= github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= @@ -52,3 +56,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU= gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= +pgregory.net/rapid v1.1.0 h1:CMa0sjHSru3puNx+J0MIAuiiEV4N0qj8/cMWGBBCsjw= +pgregory.net/rapid v1.1.0/go.mod h1:PY5XlDGj0+V1FCq0o192FdRhpKHGTRIWBgqjDBTrq04= diff --git a/indexer/postgres/tests/postgres_test.go b/indexer/postgres/tests/postgres_test.go new file mode 100644 index 000000000000..fc725f9cc1cf --- /dev/null +++ b/indexer/postgres/tests/postgres_test.go @@ -0,0 +1,99 @@ +package tests + +import ( + "context" + "os" + "strings" + "testing" + + embeddedpostgres "github.com/fergusstrange/embedded-postgres" + "github.com/hashicorp/consul/sdk/freeport" + _ "github.com/jackc/pgx/v5/stdlib" + "github.com/stretchr/testify/require" + + "cosmossdk.io/indexer/postgres" + "cosmossdk.io/schema/addressutil" + "cosmossdk.io/schema/indexer" + indexertesting "cosmossdk.io/schema/testing" + "cosmossdk.io/schema/testing/appdatasim" + "cosmossdk.io/schema/testing/statesim" +) + +func TestPostgresIndexer(t *testing.T) { + t.Run("RetainDeletions", func(t *testing.T) { + testPostgresIndexer(t, true) + }) + t.Run("NoRetainDeletions", func(t *testing.T) { + testPostgresIndexer(t, false) + }) +} + +func testPostgresIndexer(t *testing.T, retainDeletions bool) { + tempDir, err := os.MkdirTemp("", "postgres-indexer-test") + require.NoError(t, err) + + dbPort := freeport.GetOne(t) + pgConfig := embeddedpostgres.DefaultConfig(). + Port(uint32(dbPort)). + DataPath(tempDir) + + dbUrl := pgConfig.GetConnectionURL() + pg := embeddedpostgres.NewDatabase(pgConfig) + require.NoError(t, pg.Start()) + + ctx, cancel := context.WithCancel(context.Background()) + + t.Cleanup(func() { + cancel() + require.NoError(t, pg.Stop()) + err := os.RemoveAll(tempDir) + require.NoError(t, err) + }) + + cfg, err := postgresConfigToIndexerConfig(postgres.Config{ + DatabaseURL: dbUrl, + DisableRetainDeletions: !retainDeletions, + }) + require.NoError(t, err) + + debugLog := &strings.Builder{} + + pgIndexer, err := postgres.StartIndexer(indexer.InitParams{ + Config: cfg, + Context: ctx, + Logger: &prettyLogger{debugLog}, + AddressCodec: addressutil.HexAddressCodec{}, + }) + require.NoError(t, err) + + sim, err := appdatasim.NewSimulator(appdatasim.Options{ + Listener: pgIndexer.Listener, + AppSchema: indexertesting.ExampleAppSchema, + StateSimOptions: statesim.Options{ + CanRetainDeletions: retainDeletions, + }, + }) + require.NoError(t, err) + + blockDataGen := sim.BlockDataGenN(10, 100) + numBlocks := 200 + if testing.Short() { + numBlocks = 10 + } + for i := 0; i < numBlocks; i++ { + // using Example generates a deterministic data set based + // on a seed so that regression tests can be created OR rapid.Check can + // be used for fully random property-based testing + blockData := blockDataGen.Example(i) + + // process the generated block data with the simulator which will also + // send it to the indexer + require.NoError(t, sim.ProcessBlockData(blockData), debugLog.String()) + + // compare the expected state in the simulator to the actual state in the indexer and expect the diff to be empty + require.Empty(t, appdatasim.DiffAppData(sim, pgIndexer.View), debugLog.String()) + + // reset the debug log after each successful block so that it doesn't get too long when debugging + debugLog.Reset() + } +} diff --git a/indexer/postgres/view.go b/indexer/postgres/view.go new file mode 100644 index 000000000000..eac2c52f8a8b --- /dev/null +++ b/indexer/postgres/view.go @@ -0,0 +1,151 @@ +package postgres + +import ( + "context" + "database/sql" + "strings" + + "cosmossdk.io/schema" + "cosmossdk.io/schema/view" +) + +var _ view.AppData = &indexerImpl{} + +func (i *indexerImpl) AppState() view.AppState { + return i +} + +func (i *indexerImpl) BlockNum() (uint64, error) { + var blockNum int64 + err := i.tx.QueryRow("SELECT coalesce(max(number), 0) FROM block").Scan(&blockNum) + if err != nil { + return 0, err + } + return uint64(blockNum), nil +} + +type moduleView struct { + moduleIndexer + ctx context.Context + conn dbConn +} + +func (i *indexerImpl) GetModule(moduleName string) (view.ModuleState, error) { + mod, ok := i.modules[moduleName] + if !ok { + return nil, nil + } + return &moduleView{ + moduleIndexer: *mod, + ctx: i.ctx, + conn: i.tx, + }, nil +} + +func (i *indexerImpl) Modules(f func(modState view.ModuleState, err error) bool) { + for _, mod := range i.modules { + if !f(&moduleView{ + moduleIndexer: *mod, + ctx: i.ctx, + conn: i.tx, + }, nil) { + return + } + } +} + +func (i *indexerImpl) NumModules() (int, error) { + return len(i.modules), nil +} + +func (m *moduleView) ModuleName() string { + return m.moduleName +} + +func (m *moduleView) ModuleSchema() schema.ModuleSchema { + return m.schema +} + +func (m *moduleView) GetObjectCollection(objectType string) (view.ObjectCollection, error) { + obj, ok := m.tables[objectType] + if !ok { + return nil, nil + } + return &objectView{ + objectIndexer: *obj, + ctx: m.ctx, + conn: m.conn, + }, nil +} + +func (m *moduleView) ObjectCollections(f func(value view.ObjectCollection, err error) bool) { + for _, obj := range m.tables { + if !f(&objectView{ + objectIndexer: *obj, + ctx: m.ctx, + conn: m.conn, + }, nil) { + return + } + } +} + +func (m *moduleView) NumObjectCollections() (int, error) { + return len(m.tables), nil +} + +type objectView struct { + objectIndexer + ctx context.Context + conn dbConn +} + +func (tm *objectView) ObjectType() schema.ObjectType { + return tm.typ +} + +func (tm *objectView) GetObject(key interface{}) (update schema.ObjectUpdate, found bool, err error) { + return tm.get(tm.ctx, tm.conn, key) +} + +func (tm *objectView) AllState(f func(schema.ObjectUpdate, error) bool) { + buf := new(strings.Builder) + err := tm.selectAllSql(buf) + if err != nil { + panic(err) + } + + sqlStr := buf.String() + if tm.options.logger != nil { + tm.options.logger.Debug("Select", "sql", sqlStr) + } + + rows, err := tm.conn.QueryContext(tm.ctx, sqlStr) + if err != nil { + panic(err) + } + defer func(rows *sql.Rows) { + err := rows.Close() + if err != nil { + panic(err) + } + }(rows) + + for rows.Next() { + update, found, err := tm.readRow(rows) + if err == nil && !found { + err = sql.ErrNoRows + } + if !f(update, err) { + return + } + } +} + +func (tm *objectView) Len() (int, error) { + n, err := tm.count(tm.ctx, tm.conn) + if err != nil { + return 0, err + } + return n, nil +} diff --git a/indexer/postgres/where.go b/indexer/postgres/where.go new file mode 100644 index 000000000000..745092781734 --- /dev/null +++ b/indexer/postgres/where.go @@ -0,0 +1,60 @@ +package postgres + +import ( + "fmt" + "io" +) + +// whereSqlAndParams generates a WHERE clause for the provided key and returns the parameters. +func (tm *objectIndexer) whereSqlAndParams(w io.Writer, key interface{}, startParamIdx int) (endParamIdx int, keyParams []interface{}, err error) { + var keyCols []string + keyParams, keyCols, err = tm.bindKeyParams(key) + if err != nil { + return + } + + endParamIdx, keyParams, err = tm.whereSql(w, keyParams, keyCols, startParamIdx) + return +} + +// whereSql generates a WHERE clause for the provided columns and returns the parameters. +func (tm *objectIndexer) whereSql(w io.Writer, params []interface{}, cols []string, startParamIdx int) (endParamIdx int, resParams []interface{}, err error) { + _, err = fmt.Fprintf(w, " WHERE ") + if err != nil { + return 0, nil, err + } + + endParamIdx = startParamIdx + for i, col := range cols { + if i > 0 { + _, err = fmt.Fprintf(w, " AND ") + if err != nil { + return 0, nil, err + } + } + + _, err = fmt.Fprintf(w, "%s ", col) + if err != nil { + return 0, nil, err + } + + if params[i] == nil { + _, err = fmt.Fprintf(w, "IS NULL") + if err != nil { + return 0, nil, err + } + + } else { + _, err = fmt.Fprintf(w, "= $%d", endParamIdx) + if err != nil { + return 0, nil, err + } + + resParams = append(resParams, params[i]) + + endParamIdx++ + } + } + + return endParamIdx, resParams, nil +}