diff --git a/datanode/sqlstore/amm_pool_test.go b/datanode/sqlstore/amm_pool_test.go index 2099aaff56c..1e811625368 100644 --- a/datanode/sqlstore/amm_pool_test.go +++ b/datanode/sqlstore/amm_pool_test.go @@ -84,7 +84,7 @@ func TestAMMPool_Upsert(t *testing.T) { var upserted entities.AMMPool require.NoError(t, pgxscan.Get( ctx, - connectionSource.Connection, + connectionSource, &upserted, `SELECT * FROM amms WHERE party_id = $1 AND market_id = $2 AND id = $3 AND amm_party_id = $4`, partyID, marketID, poolID, ammPartyID)) @@ -121,7 +121,7 @@ func TestAMMPool_Upsert(t *testing.T) { var upserted entities.AMMPool require.NoError(t, pgxscan.Get( ctx, - connectionSource.Connection, + connectionSource, &upserted, `SELECT * FROM amms WHERE party_id = $1 AND market_id = $2 AND id = $3 AND amm_party_id = $4`, partyID, marketID, poolID, ammPartyID)) @@ -152,7 +152,7 @@ func TestAMMPool_Upsert(t *testing.T) { var upserted entities.AMMPool require.NoError(t, pgxscan.Get( ctx, - connectionSource.Connection, + connectionSource, &upserted, `SELECT * FROM amms WHERE party_id = $1 AND market_id = $2 AND id = $3 AND amm_party_id = $4`, partyID, marketID, poolID, ammPartyID)) diff --git a/datanode/sqlstore/assets_test.go b/datanode/sqlstore/assets_test.go index dd6f71123cd..72afef77975 100644 --- a/datanode/sqlstore/assets_test.go +++ b/datanode/sqlstore/assets_test.go @@ -120,11 +120,18 @@ func TestAssetCache(t *testing.T) { require.NoError(t, err) assert.Equal(t, asset2, fetched) - // Commit the transaction and fetch the asset, we should get the asset with the new symbol + // Commit the sub-transaction and fetch the asset, we should not yet get the asset with the new symbol err = connectionSource.Commit(txCtx) require.NoError(t, err) fetched, err = as.GetByID(ctx, string(asset.ID)) require.NoError(t, err) + assert.Equal(t, asset2, fetched) + + // now commit the main transaction, then we should get the new symbol + err = connectionSource.Commit(ctx) + require.NoError(t, err) + fetched, err = as.GetByID(context.Background(), string(asset.ID)) + require.NoError(t, err) assert.Equal(t, asset3, fetched) } diff --git a/datanode/sqlstore/connection_source.go b/datanode/sqlstore/connection_source.go index c75e41dff67..8ff00ed8626 100644 --- a/datanode/sqlstore/connection_source.go +++ b/datanode/sqlstore/connection_source.go @@ -94,10 +94,6 @@ func registerNumericType(poolConfig *pgxpool.Config) { } } -type delegatingConnection struct { - pool *pgxpool.Pool -} - func CreateConnectionPool(ctx context.Context, conf ConnectionConfig) (*pgxpool.Pool, error) { poolConfig, err := conf.GetPoolConfig() if err != nil { diff --git a/datanode/sqlstore/connection_tx.go b/datanode/sqlstore/connection_tx.go index 5f38d6c30f6..203ed3610c7 100644 --- a/datanode/sqlstore/connection_tx.go +++ b/datanode/sqlstore/connection_tx.go @@ -1,3 +1,18 @@ +// Copyright (C) 2023 Gobalsky Labs Limited +// +// This program is free software: you can redistribute it and/or modify +// it under the terms of the GNU Affero General Public License as +// published by the Free Software Foundation, either version 3 of the +// License, or (at your option) any later version. +// +// This program is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Affero General Public License for more details. +// +// You should have received a copy of the GNU Affero General Public License +// along with this program. If not, see . + package sqlstore import ( diff --git a/datanode/sqlstore/deposits_test.go b/datanode/sqlstore/deposits_test.go index 4e99fb2deea..317d94acf4b 100644 --- a/datanode/sqlstore/deposits_test.go +++ b/datanode/sqlstore/deposits_test.go @@ -76,7 +76,7 @@ func setupDepositStoreTests(t *testing.T) (*sqlstore.Blocks, *sqlstore.Deposits, t.Helper() bs := sqlstore.NewBlocks(connectionSource) ds := sqlstore.NewDeposits(connectionSource) - return bs, ds, connectionSource.Connection + return bs, ds, connectionSource } func testAddDepositForNewBlock(t *testing.T) { diff --git a/datanode/sqlstore/erc20_multisig_added_test.go b/datanode/sqlstore/erc20_multisig_added_test.go index 41afba8c708..fc9d330fa1e 100644 --- a/datanode/sqlstore/erc20_multisig_added_test.go +++ b/datanode/sqlstore/erc20_multisig_added_test.go @@ -40,7 +40,7 @@ func TestERC20MultiSigEvent(t *testing.T) { func setupERC20MultiSigEventStoreTests(t *testing.T) (*sqlstore.ERC20MultiSigSignerEvent, sqlstore.Connection) { t.Helper() ms := sqlstore.NewERC20MultiSigSignerEvent(connectionSource) - return ms, connectionSource.Connection + return ms, connectionSource } func testAddSigner(t *testing.T) { diff --git a/datanode/sqlstore/fees_stats_test.go b/datanode/sqlstore/fees_stats_test.go index aa029c20050..297ffb38de0 100644 --- a/datanode/sqlstore/fees_stats_test.go +++ b/datanode/sqlstore/fees_stats_test.go @@ -84,7 +84,7 @@ func testAddFeesStatsEpochNotExists(t *testing.T) { // Check that the stats were added var got entities.FeesStats - err = pgxscan.Get(ctx, connectionSource.Connection, &got, + err = pgxscan.Get(ctx, connectionSource, &got, `SELECT * FROM fees_stats WHERE market_id = $1 AND asset_id = $2 AND epoch_seq = $3`, market.ID, asset.ID, want.EpochSeq, ) diff --git a/datanode/sqlstore/funding_period_test.go b/datanode/sqlstore/funding_period_test.go index 45d55ba8926..c0db1f388cc 100644 --- a/datanode/sqlstore/funding_period_test.go +++ b/datanode/sqlstore/funding_period_test.go @@ -124,7 +124,7 @@ func testAddFundingPeriodShouldUpdateIfMarketExistsAndSequenceExists(t *testing. require.NoError(t, err) var dbResult entities.FundingPeriod - err = pgxscan.Get(ctx, stores.fp.Connection, &dbResult, `select * from funding_period where market_id = $1 and funding_period_seq = $2`, stores.markets[0].ID, 1) + err = pgxscan.Get(ctx, stores.fp, &dbResult, `select * from funding_period where market_id = $1 and funding_period_seq = $2`, stores.markets[0].ID, 1) require.NoError(t, err) assert.Equal(t, period, dbResult) @@ -139,7 +139,7 @@ func testAddFundingPeriodShouldUpdateIfMarketExistsAndSequenceExists(t *testing. err = stores.fp.AddFundingPeriod(ctx, &period) require.NoError(t, err) - err = pgxscan.Get(ctx, stores.fp.Connection, &dbResult, `select * from funding_period where market_id = $1 and funding_period_seq = $2`, stores.markets[0].ID, 1) + err = pgxscan.Get(ctx, stores.fp, &dbResult, `select * from funding_period where market_id = $1 and funding_period_seq = $2`, stores.markets[0].ID, 1) require.NoError(t, err) assert.Equal(t, period, dbResult) } @@ -241,7 +241,7 @@ func testShouldUpdateDataPointInSameBlock(t *testing.T) { require.NoError(t, err) var inserted []entities.FundingPeriodDataPoint - err = pgxscan.Select(ctx, connectionSource.Connection, &inserted, + err = pgxscan.Select(ctx, connectionSource, &inserted, `SELECT * FROM funding_period_data_points where market_id = $1 and funding_period_seq = $2 and data_point_type = $3 and vega_time = $4`, stores.markets[0].ID, 1, entities.FundingPeriodDataPointSourceExternal, stores.blocks[4].VegaTime) require.NoError(t, err) @@ -262,7 +262,7 @@ func testShouldUpdateDataPointInSameBlock(t *testing.T) { err = stores.fp.AddDataPoint(ctx, &dp2) require.NoError(t, err) - err = pgxscan.Select(ctx, connectionSource.Connection, &inserted, + err = pgxscan.Select(ctx, connectionSource, &inserted, `SELECT * FROM funding_period_data_points where market_id = $1 and funding_period_seq = $2 and data_point_type = $3 and vega_time = $4`, stores.markets[0].ID, 1, entities.FundingPeriodDataPointSourceExternal, stores.blocks[4].VegaTime) require.NoError(t, err) @@ -799,7 +799,7 @@ func TestFundingPeriodDataPointSource(t *testing.T) { } require.NoError(t, stores.fp.AddDataPoint(ctx, &dp)) got := entities.FundingPeriodDataPoint{} - require.NoError(t, pgxscan.Get(ctx, stores.fp.Connection, &got, + require.NoError(t, pgxscan.Get(ctx, stores.fp, &got, `select * from funding_period_data_points where market_id = $1 and funding_period_seq = $2 and data_point_type = $3`, dp.MarketID, dp.FundingPeriodSeq, dp.DataPointType), ) diff --git a/datanode/sqlstore/games_test.go b/datanode/sqlstore/games_test.go index e44ee232c3f..df3aec3f4a0 100644 --- a/datanode/sqlstore/games_test.go +++ b/datanode/sqlstore/games_test.go @@ -617,9 +617,9 @@ func setupGamesData(ctx context.Context, t *testing.T, stores gameStores, block } // IMPORTANT!!!! We MUST refresh the materialized views or the tests will fail because there will be NO DATA!!! - _, err := connectionSource.Connection.Exec(ctx, "REFRESH MATERIALIZED VIEW game_stats") + _, err := connectionSource.Exec(ctx, "REFRESH MATERIALIZED VIEW game_stats") require.NoError(t, err) - _, err = connectionSource.Connection.Exec(ctx, "REFRESH MATERIALIZED VIEW game_stats_current") + _, err = connectionSource.Exec(ctx, "REFRESH MATERIALIZED VIEW game_stats_current") require.NoError(t, err) return orderResults(results), gameIDs, rewards, teams, individuals diff --git a/datanode/sqlstore/liquidity_provision_test.go b/datanode/sqlstore/liquidity_provision_test.go index df955e7c4e0..d8b76a40bed 100644 --- a/datanode/sqlstore/liquidity_provision_test.go +++ b/datanode/sqlstore/liquidity_provision_test.go @@ -61,7 +61,7 @@ func setupLPTests(t *testing.T) (*sqlstore.Blocks, *sqlstore.LiquidityProvision, bs := sqlstore.NewBlocks(connectionSource) lp := sqlstore.NewLiquidityProvision(connectionSource, logging.NewTestLogger()) - return bs, lp, connectionSource.Connection + return bs, lp, connectionSource } func testInsertNewInCurrentBlock(t *testing.T) { diff --git a/datanode/sqlstore/margin_level_test.go b/datanode/sqlstore/margin_level_test.go index 66c196e7899..379215b40ba 100644 --- a/datanode/sqlstore/margin_level_test.go +++ b/datanode/sqlstore/margin_level_test.go @@ -91,7 +91,7 @@ func setupMarginLevelTests(t *testing.T, ctx context.Context) (*testBlockSource, accountStore := sqlstore.NewAccounts(connectionSource) ml := sqlstore.NewMarginLevels(connectionSource) - return testBlockSource, ml, accountStore, connectionSource.Connection + return testBlockSource, ml, accountStore, connectionSource } func testInsertMarginLevels(t *testing.T) { diff --git a/datanode/sqlstore/market_data_test.go b/datanode/sqlstore/market_data_test.go index 06f1c273af6..074d1410b76 100644 --- a/datanode/sqlstore/market_data_test.go +++ b/datanode/sqlstore/market_data_test.go @@ -97,7 +97,7 @@ func shouldWorkForAllValuesOfCompositePriceType(t *testing.T) { addMarketData(t, ctx, "AUCTION_TRIGGER_LIQUIDITY", pt) var got entities.MarketData - err := connectionSource.Connection.QueryRow(ctx, `select mark_price_type from market_data`).Scan(&got.MarkPriceType) + err := connectionSource.QueryRow(ctx, `select mark_price_type from market_data`).Scan(&got.MarkPriceType) require.NoError(t, err) mdProto := got.ToProto() @@ -152,7 +152,7 @@ func shouldWorkForAllValuesOfAuctionTrigger(t *testing.T) { addMarketData(t, ctx, trigger, "COMPOSITE_PRICE_TYPE_LAST_TRADE") var got entities.MarketData - err := connectionSource.Connection.QueryRow(ctx, `select auction_trigger from market_data`).Scan(&got.AuctionTrigger) + err := connectionSource.QueryRow(ctx, `select auction_trigger from market_data`).Scan(&got.AuctionTrigger) require.NoError(t, err) mdProto := got.ToProto() @@ -170,7 +170,7 @@ func shouldInsertAValidMarketDataRecord(t *testing.T) { var rowCount int - err := connectionSource.Connection.QueryRow(ctx, `select count(*) from market_data`).Scan(&rowCount) + err := connectionSource.QueryRow(ctx, `select count(*) from market_data`).Scan(&rowCount) require.NoError(t, err) assert.Equal(t, 0, rowCount) @@ -202,7 +202,7 @@ func shouldInsertAValidMarketDataRecord(t *testing.T) { _, err = md.Flush(ctx) require.NoError(t, err) - err = connectionSource.Connection.QueryRow(ctx, `select count(*) from market_data`).Scan(&rowCount) + err = connectionSource.QueryRow(ctx, `select count(*) from market_data`).Scan(&rowCount) assert.NoError(t, err) assert.Equal(t, 1, rowCount) } diff --git a/datanode/sqlstore/markets_test.go b/datanode/sqlstore/markets_test.go index a014e7cf23d..34d9a8de1c7 100644 --- a/datanode/sqlstore/markets_test.go +++ b/datanode/sqlstore/markets_test.go @@ -203,10 +203,9 @@ func shouldInsertAValidMarketRecord(t *testing.T) { ctx := tempTransaction(t) - conn := connectionSource.Connection var rowCount int - err := conn.QueryRow(ctx, `select count(*) from markets`).Scan(&rowCount) + err := connectionSource.QueryRow(ctx, `select count(*) from markets`).Scan(&rowCount) require.NoError(t, err) assert.Equal(t, 0, rowCount) @@ -219,7 +218,7 @@ func shouldInsertAValidMarketRecord(t *testing.T) { err = md.Upsert(ctx, market) require.NoError(t, err, "Saving market entity to database") - err = conn.QueryRow(ctx, `select count(*) from markets`).Scan(&rowCount) + err = connectionSource.QueryRow(ctx, `select count(*) from markets`).Scan(&rowCount) assert.NoError(t, err) assert.Equal(t, 1, rowCount) } @@ -235,11 +234,10 @@ func shouldUpdateAValidMarketRecord(t *testing.T) { bs, md := setupMarketsTest(t) ctx := tempTransaction(t) - conn := connectionSource.Connection var rowCount int t.Run("should have no markets in the database", func(t *testing.T) { - err := conn.QueryRow(ctx, `select count(*) from markets`).Scan(&rowCount) + err := connectionSource.QueryRow(ctx, `select count(*) from markets`).Scan(&rowCount) require.NoError(t, err) assert.Equal(t, 0, rowCount) }) @@ -258,7 +256,7 @@ func shouldUpdateAValidMarketRecord(t *testing.T) { require.NoError(t, err, "Saving market entity to database") var got entities.Market - err = pgxscan.Get(ctx, conn, &got, `select * from markets where id = $1 and vega_time = $2`, market.ID, market.VegaTime) + err = pgxscan.Get(ctx, connectionSource, &got, `select * from markets where id = $1 and vega_time = $2`, market.ID, market.VegaTime) assert.NoError(t, err) assert.Equal(t, "TEST_INSTRUMENT", market.InstrumentID) assert.NotNil(t, got.LiquidationStrategy) @@ -278,7 +276,7 @@ func shouldUpdateAValidMarketRecord(t *testing.T) { require.NoError(t, err, "Saving market entity to database") var got entities.Market - err = pgxscan.Get(ctx, conn, &got, `select * from markets where id = $1 and vega_time = $2`, market.ID, market.VegaTime) + err = pgxscan.Get(ctx, connectionSource, &got, `select * from markets where id = $1 and vega_time = $2`, market.ID, market.VegaTime) assert.NoError(t, err) assert.Equal(t, "TEST_INSTRUMENT", market.InstrumentID) @@ -297,7 +295,7 @@ func shouldUpdateAValidMarketRecord(t *testing.T) { require.NoError(t, err, "Saving market entity to database") var got entities.Market - err = pgxscan.Get(ctx, conn, &got, `select * from markets where id = $1 and vega_time = $2`, market.ID, market.VegaTime) + err = pgxscan.Get(ctx, connectionSource, &got, `select * from markets where id = $1 and vega_time = $2`, market.ID, market.VegaTime) assert.NoError(t, err) assert.Equal(t, "TEST_INSTRUMENT", market.InstrumentID) @@ -315,19 +313,19 @@ func shouldUpdateAValidMarketRecord(t *testing.T) { err = md.Upsert(ctx, market) require.NoError(t, err, "Saving market entity to database") - err = conn.QueryRow(ctx, `select count(*) from markets`).Scan(&rowCount) + err = connectionSource.QueryRow(ctx, `select count(*) from markets`).Scan(&rowCount) require.NoError(t, err) assert.Equal(t, 3, rowCount) var gotFirstBlock, gotSecondBlock entities.Market - err = pgxscan.Get(ctx, conn, &gotFirstBlock, `select * from markets where id = $1 and vega_time = $2`, market.ID, block.VegaTime) + err = pgxscan.Get(ctx, connectionSource, &gotFirstBlock, `select * from markets where id = $1 and vega_time = $2`, market.ID, block.VegaTime) assert.NoError(t, err) assert.Equal(t, "TEST_INSTRUMENT", market.InstrumentID) assert.Equal(t, marketProto.TradableInstrument, gotFirstBlock.TradableInstrument.ToProto()) - err = pgxscan.Get(ctx, conn, &gotSecondBlock, `select * from markets where id = $1 and vega_time = $2`, market.ID, newBlock.VegaTime) + err = pgxscan.Get(ctx, connectionSource, &gotSecondBlock, `select * from markets where id = $1 and vega_time = $2`, market.ID, newBlock.VegaTime) assert.NoError(t, err) assert.Equal(t, "TEST_INSTRUMENT", market.InstrumentID) @@ -340,10 +338,9 @@ func shouldInsertAValidSpotMarketRecord(t *testing.T) { ctx := tempTransaction(t) - conn := connectionSource.Connection var rowCount int - err := conn.QueryRow(ctx, `select count(*) from markets`).Scan(&rowCount) + err := connectionSource.QueryRow(ctx, `select count(*) from markets`).Scan(&rowCount) require.NoError(t, err) assert.Equal(t, 0, rowCount) @@ -356,7 +353,7 @@ func shouldInsertAValidSpotMarketRecord(t *testing.T) { err = md.Upsert(ctx, market) require.NoError(t, err, "Saving market entity to database") - err = conn.QueryRow(ctx, `select count(*) from markets`).Scan(&rowCount) + err = connectionSource.QueryRow(ctx, `select count(*) from markets`).Scan(&rowCount) assert.NoError(t, err) assert.Equal(t, 1, rowCount) } @@ -366,10 +363,9 @@ func shouldInsertAValidPerpetualMarketRecord(t *testing.T) { ctx := tempTransaction(t) - conn := connectionSource.Connection var rowCount int - err := conn.QueryRow(ctx, `select count(*) from markets`).Scan(&rowCount) + err := connectionSource.QueryRow(ctx, `select count(*) from markets`).Scan(&rowCount) require.NoError(t, err) assert.Equal(t, 0, rowCount) @@ -382,7 +378,7 @@ func shouldInsertAValidPerpetualMarketRecord(t *testing.T) { err = md.Upsert(ctx, market) require.NoError(t, err, "Saving market entity to database") - err = conn.QueryRow(ctx, `select count(*) from markets`).Scan(&rowCount) + err = connectionSource.QueryRow(ctx, `select count(*) from markets`).Scan(&rowCount) assert.NoError(t, err) assert.Equal(t, 1, rowCount) } @@ -1286,7 +1282,6 @@ func testMarketLineageCreated(t *testing.T) { ParentMarketID: successorMarketA.ID, } - conn := connectionSource.Connection var rowCount int64 source := &testBlockSource{bs, time.Now()} @@ -1296,7 +1291,7 @@ func testMarketLineageCreated(t *testing.T) { parentMarket.State = entities.MarketStateProposed err := md.Upsert(ctx, &parentMarket) require.NoError(t, err) - err = conn.QueryRow(ctx, `select count(*) from market_lineage where market_id = $1`, parentMarket.ID).Scan(&rowCount) + err = connectionSource.QueryRow(ctx, `select count(*) from market_lineage where market_id = $1`, parentMarket.ID).Scan(&rowCount) require.NoError(t, err) assert.Equal(t, int64(0), rowCount) @@ -1315,7 +1310,7 @@ func testMarketLineageCreated(t *testing.T) { require.NoError(t, err) var marketID, parentMarketID, rootID entities.MarketID - err = conn.QueryRow(ctx, + err = connectionSource.QueryRow(ctx, `select market_id, parent_market_id, root_id from market_lineage where market_id = $1`, parentMarket.ID, ).Scan(&marketID, &parentMarketID, &rootID) @@ -1332,7 +1327,7 @@ func testMarketLineageCreated(t *testing.T) { err := md.Upsert(ctx, &successorMarketA) require.NoError(t, err) // proposed market successor only, so it should not create a lineage record yet - err = conn.QueryRow(ctx, `select count(*) from market_lineage where market_id = $1`, successorMarketA.ID).Scan(&rowCount) + err = connectionSource.QueryRow(ctx, `select count(*) from market_lineage where market_id = $1`, successorMarketA.ID).Scan(&rowCount) require.NoError(t, err) assert.Equal(t, int64(0), rowCount) @@ -1351,7 +1346,7 @@ func testMarketLineageCreated(t *testing.T) { require.NoError(t, err) // proposed market successor has been accepted and is pending, so we should now have a lineage record pointing to the parent var marketID, parentMarketID, rootID entities.MarketID - err = conn.QueryRow(ctx, + err = connectionSource.QueryRow(ctx, `select market_id, parent_market_id, root_id from market_lineage where market_id = $1`, successorMarketA.ID, ).Scan(&marketID, &parentMarketID, &rootID) @@ -1368,7 +1363,7 @@ func testMarketLineageCreated(t *testing.T) { err := md.Upsert(ctx, &successorMarketB) require.NoError(t, err) // proposed market successor only, so it should not create a lineage record yet - err = conn.QueryRow(ctx, `select count(*) from market_lineage where market_id = $1`, successorMarketB.ID).Scan(&rowCount) + err = connectionSource.QueryRow(ctx, `select count(*) from market_lineage where market_id = $1`, successorMarketB.ID).Scan(&rowCount) require.NoError(t, err) assert.Equal(t, int64(0), rowCount) @@ -1386,7 +1381,7 @@ func testMarketLineageCreated(t *testing.T) { err = md.Upsert(ctx, &successorMarketB) require.NoError(t, err) var marketID, parentMarketID, rootID entities.MarketID - err = conn.QueryRow(ctx, + err = connectionSource.QueryRow(ctx, `select market_id, parent_market_id, root_id from market_lineage where market_id = $1`, successorMarketB.ID, ).Scan(&marketID, &parentMarketID, &rootID) diff --git a/datanode/sqlstore/node_test.go b/datanode/sqlstore/node_test.go index b977d2d1870..5e1f736d964 100644 --- a/datanode/sqlstore/node_test.go +++ b/datanode/sqlstore/node_test.go @@ -609,7 +609,7 @@ func TestNodeValidatorStatusEnum(t *testing.T) { } addRankingScore(t, ctx, ns, node1, score) var got entities.RankingScore - require.NoError(t, pgxscan.Get(ctx, connectionSource.Connection, &got, ` + require.NoError(t, pgxscan.Get(ctx, connectionSource, &got, ` SELECT stake_score, performance_score, diff --git a/datanode/sqlstore/notary_test.go b/datanode/sqlstore/notary_test.go index 3a53db51063..62bb2bd9f69 100644 --- a/datanode/sqlstore/notary_test.go +++ b/datanode/sqlstore/notary_test.go @@ -39,7 +39,7 @@ func setupNotaryStoreTests(t *testing.T) (*sqlstore.Notary, *sqlstore.Blocks, sq t.Helper() ns := sqlstore.NewNotary(connectionSource) bs := sqlstore.NewBlocks(connectionSource) - return ns, bs, connectionSource.Connection + return ns, bs, connectionSource } func testAddSignatures(t *testing.T) { diff --git a/datanode/sqlstore/oracle_data_test.go b/datanode/sqlstore/oracle_data_test.go index 264c693378f..0f8d7313496 100644 --- a/datanode/sqlstore/oracle_data_test.go +++ b/datanode/sqlstore/oracle_data_test.go @@ -43,7 +43,7 @@ func setupOracleDataTest(t *testing.T) (*sqlstore.Blocks, *sqlstore.OracleData, t.Helper() bs := sqlstore.NewBlocks(connectionSource) od := sqlstore.NewOracleData(connectionSource) - return bs, od, connectionSource.Connection + return bs, od, connectionSource } func testAddAndRetrieveOracleDataWithError(t *testing.T) { diff --git a/datanode/sqlstore/oracle_spec_test.go b/datanode/sqlstore/oracle_spec_test.go index e42949a23be..f41f807585a 100644 --- a/datanode/sqlstore/oracle_spec_test.go +++ b/datanode/sqlstore/oracle_spec_test.go @@ -46,7 +46,7 @@ func setupOracleSpecTest(t *testing.T) (*sqlstore.Blocks, *sqlstore.OracleSpec, bs := sqlstore.NewBlocks(connectionSource) os := sqlstore.NewOracleSpec(connectionSource) - return bs, os, connectionSource.Connection + return bs, os, connectionSource } func testInsertIntoNewBlock(t *testing.T) { diff --git a/datanode/sqlstore/paid_liquidity_fee_stats_test.go b/datanode/sqlstore/paid_liquidity_fee_stats_test.go index cdd42ed2901..0d10f1d3b2b 100644 --- a/datanode/sqlstore/paid_liquidity_fee_stats_test.go +++ b/datanode/sqlstore/paid_liquidity_fee_stats_test.go @@ -86,7 +86,7 @@ func testAddPaidLiquidityFeesStatsEpochIfNotExists(t *testing.T) { // Check that the stats were added var got entities.PaidLiquidityFeesStats - err = pgxscan.Get(ctx, connectionSource.Connection, &got, + err = pgxscan.Get(ctx, connectionSource, &got, `SELECT market_id, asset_id, epoch_seq, total_fees_paid, fees_paid_per_party as fees_per_party FROM paid_liquidity_fees WHERE market_id = $1 AND asset_id = $2 AND epoch_seq = $3`, market.ID, asset.ID, want.EpochSeq, diff --git a/datanode/sqlstore/party_locked_balance_test.go b/datanode/sqlstore/party_locked_balance_test.go index 97c9c494f56..ffb8e7948c4 100644 --- a/datanode/sqlstore/party_locked_balance_test.go +++ b/datanode/sqlstore/party_locked_balance_test.go @@ -168,12 +168,12 @@ func TestPartyLockedBalance_Add(t *testing.T) { var partyLockedBalances []entities.PartyLockedBalance var partyLockedBalancesCurrent []entities.PartyLockedBalance - err := pgxscan.Select(ctx, connectionSource.Connection, &partyLockedBalances, "SELECT * from party_locked_balances") + err := pgxscan.Select(ctx, connectionSource, &partyLockedBalances, "SELECT * from party_locked_balances") require.NoError(t, err) assert.Len(t, partyLockedBalances, 0) - err = pgxscan.Select(ctx, connectionSource.Connection, &partyLockedBalancesCurrent, "SELECT * from party_locked_balances_current") + err = pgxscan.Select(ctx, connectionSource, &partyLockedBalancesCurrent, "SELECT * from party_locked_balances_current") require.NoError(t, err) assert.Len(t, partyLockedBalancesCurrent, 0) @@ -193,14 +193,14 @@ func TestPartyLockedBalance_Add(t *testing.T) { err := plbs.Add(ctx, want) require.NoError(t, err) - err = pgxscan.Select(ctx, connectionSource.Connection, &partyLockedBalances, "SELECT * from party_locked_balances") + err = pgxscan.Select(ctx, connectionSource, &partyLockedBalances, "SELECT * from party_locked_balances") require.NoError(t, err) assert.Len(t, partyLockedBalances, 1) assert.Equal(t, want, partyLockedBalances[0]) t.Run("And a record into the party_locked_balances_current table if it doesn't already exist", func(t *testing.T) { - err = pgxscan.Select(ctx, connectionSource.Connection, &partyLockedBalancesCurrent, "SELECT * from party_locked_balances_current") + err = pgxscan.Select(ctx, connectionSource, &partyLockedBalancesCurrent, "SELECT * from party_locked_balances_current") require.NoError(t, err) assert.Len(t, partyLockedBalancesCurrent, 1) @@ -219,14 +219,14 @@ func TestPartyLockedBalance_Add(t *testing.T) { } err = plbs.Add(ctx, want2) - err = pgxscan.Select(ctx, connectionSource.Connection, &partyLockedBalances, "SELECT * from party_locked_balances order by vega_time") + err = pgxscan.Select(ctx, connectionSource, &partyLockedBalances, "SELECT * from party_locked_balances order by vega_time") require.NoError(t, err) assert.Len(t, partyLockedBalances, 2) assert.Equal(t, want, partyLockedBalances[0]) assert.Equal(t, want2, partyLockedBalances[1]) - err = pgxscan.Select(ctx, connectionSource.Connection, &partyLockedBalancesCurrent, "SELECT * from party_locked_balances_current") + err = pgxscan.Select(ctx, connectionSource, &partyLockedBalancesCurrent, "SELECT * from party_locked_balances_current") require.NoError(t, err) assert.Len(t, partyLockedBalancesCurrent, 1) diff --git a/datanode/sqlstore/party_vesting_balance_test.go b/datanode/sqlstore/party_vesting_balance_test.go index be50be84d3b..98f940db800 100644 --- a/datanode/sqlstore/party_vesting_balance_test.go +++ b/datanode/sqlstore/party_vesting_balance_test.go @@ -44,12 +44,12 @@ func TestPartyVestingBalance_Add(t *testing.T) { var partyVestingBalances []entities.PartyVestingBalance var partyVestingBalancesCurrent []entities.PartyVestingBalance - err := pgxscan.Select(ctx, connectionSource.Connection, &partyVestingBalances, "SELECT * from party_vesting_balances") + err := pgxscan.Select(ctx, connectionSource, &partyVestingBalances, "SELECT * from party_vesting_balances") require.NoError(t, err) assert.Len(t, partyVestingBalances, 0) - err = pgxscan.Select(ctx, connectionSource.Connection, &partyVestingBalancesCurrent, "SELECT * from party_vesting_balances_current") + err = pgxscan.Select(ctx, connectionSource, &partyVestingBalancesCurrent, "SELECT * from party_vesting_balances_current") require.NoError(t, err) assert.Len(t, partyVestingBalancesCurrent, 0) @@ -68,14 +68,14 @@ func TestPartyVestingBalance_Add(t *testing.T) { err := plbs.Add(ctx, want) require.NoError(t, err) - err = pgxscan.Select(ctx, connectionSource.Connection, &partyVestingBalances, "SELECT * from party_vesting_balances") + err = pgxscan.Select(ctx, connectionSource, &partyVestingBalances, "SELECT * from party_vesting_balances") require.NoError(t, err) assert.Len(t, partyVestingBalances, 1) assert.Equal(t, want, partyVestingBalances[0]) t.Run("And a record into the party_vesting_balances_current table if it doesn't already exist", func(t *testing.T) { - err = pgxscan.Select(ctx, connectionSource.Connection, &partyVestingBalancesCurrent, "SELECT * from party_vesting_balances_current") + err = pgxscan.Select(ctx, connectionSource, &partyVestingBalancesCurrent, "SELECT * from party_vesting_balances_current") require.NoError(t, err) assert.Len(t, partyVestingBalancesCurrent, 1) @@ -93,14 +93,14 @@ func TestPartyVestingBalance_Add(t *testing.T) { } err = plbs.Add(ctx, want2) - err = pgxscan.Select(ctx, connectionSource.Connection, &partyVestingBalances, "SELECT * from party_vesting_balances order by vega_time") + err = pgxscan.Select(ctx, connectionSource, &partyVestingBalances, "SELECT * from party_vesting_balances order by vega_time") require.NoError(t, err) assert.Len(t, partyVestingBalances, 2) assert.Equal(t, want, partyVestingBalances[0]) assert.Equal(t, want2, partyVestingBalances[1]) - err = pgxscan.Select(ctx, connectionSource.Connection, &partyVestingBalancesCurrent, "SELECT * from party_vesting_balances_current") + err = pgxscan.Select(ctx, connectionSource, &partyVestingBalancesCurrent, "SELECT * from party_vesting_balances_current") require.NoError(t, err) assert.Len(t, partyVestingBalancesCurrent, 1) diff --git a/datanode/sqlstore/proposals_test.go b/datanode/sqlstore/proposals_test.go index 856765e9b9d..df66ebe7b10 100644 --- a/datanode/sqlstore/proposals_test.go +++ b/datanode/sqlstore/proposals_test.go @@ -1120,11 +1120,11 @@ func createPaginationTestProposals(t *testing.T, ctx context.Context, pps *sqlst func cleanupTestProposals(t *testing.T) { t.Helper() // Remove the proposals, then the parties and then the blocks - _, err := connectionSource.Connection.Exec(context.Background(), `TRUNCATE TABLE proposals`) + _, err := connectionSource.Exec(context.Background(), `TRUNCATE TABLE proposals`) require.NoError(t, err) - _, err = connectionSource.Connection.Exec(context.Background(), `TRUNCATE TABLE parties`) + _, err = connectionSource.Exec(context.Background(), `TRUNCATE TABLE parties`) require.NoError(t, err) - _, err = connectionSource.Connection.Exec(context.Background(), `TRUNCATE TABLE blocks`) + _, err = connectionSource.Exec(context.Background(), `TRUNCATE TABLE blocks`) require.NoError(t, err) } diff --git a/datanode/sqlstore/referral_programs_test.go b/datanode/sqlstore/referral_programs_test.go index 96377f9a5c6..6d1cdf39356 100644 --- a/datanode/sqlstore/referral_programs_test.go +++ b/datanode/sqlstore/referral_programs_test.go @@ -121,7 +121,7 @@ func TestReferralPrograms_AddReferralProgram(t *testing.T) { require.NoError(t, err) var got []entities.ReferralProgram - err = pgxscan.Select(ctx, connectionSource.Connection, &got, "SELECT * FROM referral_programs") + err = pgxscan.Select(ctx, connectionSource, &got, "SELECT * FROM referral_programs") require.NoError(t, err) require.Len(t, got, 1) assert.Equal(t, *want, got[0]) @@ -130,7 +130,7 @@ func TestReferralPrograms_AddReferralProgram(t *testing.T) { err = rs.AddReferralProgram(ctx, want2) require.NoError(t, err) - err = pgxscan.Select(ctx, connectionSource.Connection, &got, "SELECT * FROM referral_programs") + err = pgxscan.Select(ctx, connectionSource, &got, "SELECT * FROM referral_programs") require.NoError(t, err) require.Len(t, got, 2) wantAll := []entities.ReferralProgram{*want, *want2} @@ -238,7 +238,7 @@ func TestReferralPrograms_UpdateReferralProgram(t *testing.T) { require.NoError(t, err) var got []entities.ReferralProgram - err = pgxscan.Select(ctx, connectionSource.Connection, &got, "SELECT * FROM referral_programs") + err = pgxscan.Select(ctx, connectionSource, &got, "SELECT * FROM referral_programs") require.NoError(t, err) require.Len(t, got, 1) @@ -249,7 +249,7 @@ func TestReferralPrograms_UpdateReferralProgram(t *testing.T) { err = rs.UpdateReferralProgram(ctx, wantUpdated) require.NoError(t, err) - err = pgxscan.Select(ctx, connectionSource.Connection, &got, "SELECT * FROM referral_programs") + err = pgxscan.Select(ctx, connectionSource, &got, "SELECT * FROM referral_programs") require.NoError(t, err) require.Len(t, got, 2) @@ -260,7 +260,7 @@ func TestReferralPrograms_UpdateReferralProgram(t *testing.T) { t.Run("The current_referral view should list the updated referral program record", func(t *testing.T) { var got []entities.ReferralProgram - err := pgxscan.Select(ctx, connectionSource.Connection, &got, "SELECT * FROM current_referral_program") + err := pgxscan.Select(ctx, connectionSource, &got, "SELECT * FROM current_referral_program") require.NoError(t, err) require.Len(t, got, 1) assert.Equal(t, *wantUpdated, got[0]) @@ -294,13 +294,13 @@ func TestReferralPrograms_EndReferralProgram(t *testing.T) { ended.EndedAt = &endTime var got []entities.ReferralProgram - err = pgxscan.Select(ctx, connectionSource.Connection, &got, "SELECT * FROM referral_programs order by vega_time") + err = pgxscan.Select(ctx, connectionSource, &got, "SELECT * FROM referral_programs order by vega_time") require.NoError(t, err) require.Len(t, got, 3) wantAll := []entities.ReferralProgram{*started, *updated, *ended} assert.Equal(t, wantAll, got) - err = pgxscan.Select(ctx, connectionSource.Connection, &got, "SELECT * FROM current_referral_program") + err = pgxscan.Select(ctx, connectionSource, &got, "SELECT * FROM current_referral_program") require.NoError(t, err) require.Len(t, got, 1) assert.Equal(t, *ended, got[0]) diff --git a/datanode/sqlstore/referral_sets_test.go b/datanode/sqlstore/referral_sets_test.go index 07d09482f73..06b474757d2 100644 --- a/datanode/sqlstore/referral_sets_test.go +++ b/datanode/sqlstore/referral_sets_test.go @@ -66,7 +66,7 @@ func TestReferralSets_AddReferralSet(t *testing.T) { require.NoError(t, err) var got entities.ReferralSet - err = pgxscan.Get(ctx, connectionSource.Connection, &got, "SELECT * FROM referral_sets WHERE id = $1", set.ID) + err = pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM referral_sets WHERE id = $1", set.ID) require.NoError(t, err) assert.Equal(t, set, got) }) @@ -111,7 +111,7 @@ func TestReferralSets_RefereeJoinedReferralSet(t *testing.T) { require.NoError(t, err) var got entities.ReferralSetReferee - err = pgxscan.Get(ctx, connectionSource.Connection, &got, "SELECT * FROM referral_set_referees WHERE referral_set_id = $1 AND referee = $2", set.ID, referee.ID) + err = pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM referral_set_referees WHERE referral_set_id = $1 AND referee = $2", set.ID, referee.ID) require.NoError(t, err) assert.Equal(t, setReferee, got) }) @@ -555,7 +555,7 @@ func TestReferralSets_AddReferralSetStats(t *testing.T) { require.NoError(t, err) var got entities.ReferralSetStats - err = pgxscan.Get(ctx, connectionSource.Connection, &got, "SELECT * FROM referral_set_stats WHERE set_id = $1 AND at_epoch = $2", set.ID, epoch) + err = pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM referral_set_stats WHERE set_id = $1 AND at_epoch = $2", set.ID, epoch) require.NoError(t, err) assert.Equal(t, stats, got) }) @@ -578,7 +578,7 @@ func TestReferralSets_AddReferralSetStats(t *testing.T) { err := rs.AddReferralSetStats(ctx, &stats) require.NoError(t, err) var got entities.ReferralSetStats - err = pgxscan.Get(ctx, connectionSource.Connection, &got, "SELECT * FROM referral_set_stats WHERE set_id = $1 AND at_epoch = $2", set.ID, epoch) + err = pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM referral_set_stats WHERE set_id = $1 AND at_epoch = $2", set.ID, epoch) require.NoError(t, err) assert.Equal(t, stats, got) diff --git a/datanode/sqlstore/rewards_test.go b/datanode/sqlstore/rewards_test.go index 0f948f8145a..c36a04ffe43 100644 --- a/datanode/sqlstore/rewards_test.go +++ b/datanode/sqlstore/rewards_test.go @@ -930,7 +930,7 @@ func TestRewardsGameTotals(t *testing.T) { }, } for _, team := range teams { - _, err := connectionSource.Connection.Exec(ctx, + _, err := connectionSource.Exec(ctx, `INSERT INTO teams (id, referrer, name, team_url, avatar_url, closed, created_at_epoch, created_at, vega_time) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)`, team.ID, team.Referrer, team.Name, team.TeamURL, team.AvatarURL, team.Closed, team.CreatedAtEpoch, team.CreatedAt, team.VegaTime) @@ -961,7 +961,7 @@ func TestRewardsGameTotals(t *testing.T) { }, } for _, member := range teamMembers { - _, err := connectionSource.Connection.Exec(ctx, + _, err := connectionSource.Exec(ctx, `INSERT INTO team_members (team_id, party_id, joined_at_epoch, joined_at, vega_time) VALUES ($1, $2, $3, $4, $5)`, member.TeamID, member.PartyID, member.JoinedAtEpoch, member.JoinedAt, member.VegaTime) @@ -1001,7 +1001,7 @@ func TestRewardsGameTotals(t *testing.T) { }, } for _, total := range existingTotals { - _, err := connectionSource.Connection.Exec(ctx, + _, err := connectionSource.Exec(ctx, `INSERT INTO game_reward_totals (game_id, party_id, asset_id, market_id, epoch_id, team_id, total_rewards, total_rewards_quantum) VALUES ($1, $2, $3, $4, $5, $6, $7, $8)`, total.GameID, total.PartyID, total.AssetID, total.MarketID, total.EpochID, total.TeamID, total.TotalRewards, total.TotalRewardsQuantum) @@ -1149,7 +1149,7 @@ func TestRewardsGameTotals(t *testing.T) { } for _, tc := range testCases { var totals []entities.RewardTotals - require.NoError(t, pgxscan.Select(ctx, connectionSource.Connection, &totals, + require.NoError(t, pgxscan.Select(ctx, connectionSource, &totals, `SELECT * FROM game_reward_totals WHERE game_id = $1 AND party_id = $2 AND epoch_id = $3`, tc.game_id, tc.party_id, tc.epoch_id)) assert.Equal(t, 1, len(totals)) diff --git a/datanode/sqlstore/risk_factor_test.go b/datanode/sqlstore/risk_factor_test.go index b8aa45a7e01..791b390da84 100644 --- a/datanode/sqlstore/risk_factor_test.go +++ b/datanode/sqlstore/risk_factor_test.go @@ -89,7 +89,7 @@ func testAddRiskFactor(t *testing.T) { bs, rfStore := setupRiskFactorTests(t) var rowCount int - err := connectionSource.Connection.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount) + err := connectionSource.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount) assert.NoError(t, err) block := addTestBlock(t, ctx, bs) @@ -100,7 +100,7 @@ func testAddRiskFactor(t *testing.T) { err = rfStore.Upsert(ctx, riskFactor) require.NoError(t, err) - err = connectionSource.Connection.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount) + err = connectionSource.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount) assert.NoError(t, err) assert.Equal(t, 1, rowCount) } @@ -111,7 +111,7 @@ func testUpsertDuplicateMarketInSameBlock(t *testing.T) { bs, rfStore := setupRiskFactorTests(t) var rowCount int - err := connectionSource.Connection.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount) + err := connectionSource.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount) assert.NoError(t, err) block := addTestBlock(t, ctx, bs) @@ -122,14 +122,14 @@ func testUpsertDuplicateMarketInSameBlock(t *testing.T) { err = rfStore.Upsert(ctx, riskFactor) require.NoError(t, err) - err = connectionSource.Connection.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount) + err = connectionSource.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount) assert.NoError(t, err) assert.Equal(t, 1, rowCount) err = rfStore.Upsert(ctx, riskFactor) require.NoError(t, err) - err = connectionSource.Connection.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount) + err = connectionSource.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount) assert.NoError(t, err) assert.Equal(t, 1, rowCount) } @@ -148,7 +148,7 @@ func testGetMarketRiskFactors(t *testing.T) { bs, rfStore := setupRiskFactorTests(t) var rowCount int - err := connectionSource.Connection.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount) + err := connectionSource.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount) assert.NoError(t, err) block := addTestBlock(t, ctx, bs) @@ -159,7 +159,7 @@ func testGetMarketRiskFactors(t *testing.T) { err = rfStore.Upsert(ctx, riskFactor) require.NoError(t, err) - err = connectionSource.Connection.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount) + err = connectionSource.QueryRow(ctx, `select count(*) from risk_factors`).Scan(&rowCount) assert.NoError(t, err) assert.Equal(t, 1, rowCount) diff --git a/datanode/sqlstore/snapshot_data_test.go b/datanode/sqlstore/snapshot_data_test.go index 4e7458c6661..6f61471cf44 100644 --- a/datanode/sqlstore/snapshot_data_test.go +++ b/datanode/sqlstore/snapshot_data_test.go @@ -44,7 +44,7 @@ func TestGetSnapshots(t *testing.T) { addSnapshot(t, ctx, ss, bs, entities.CoreSnapshotData{BlockHeight: 100, VegaCoreVersion: "v0.65.0"}) var rowCount int - err := connectionSource.Connection.QueryRow(ctx, `select count(*) from core_snapshots`).Scan(&rowCount) + err := connectionSource.QueryRow(ctx, `select count(*) from core_snapshots`).Scan(&rowCount) require.NoError(t, err) require.Equal(t, 1, rowCount) diff --git a/datanode/sqlstore/stake_linking_test.go b/datanode/sqlstore/stake_linking_test.go index c348e4fea63..39313dfe714 100644 --- a/datanode/sqlstore/stake_linking_test.go +++ b/datanode/sqlstore/stake_linking_test.go @@ -51,8 +51,7 @@ func testUpsertShouldAddNewInBlock(t *testing.T) { bs, sl := setupStakeLinkingTest(t) var rowCount int - conn := connectionSource.Connection - assert.NoError(t, conn.QueryRow(ctx, "select count(*) from stake_linking").Scan(&rowCount)) + assert.NoError(t, connectionSource.QueryRow(ctx, "select count(*) from stake_linking").Scan(&rowCount)) assert.Equal(t, 0, rowCount) block := addTestBlock(t, ctx, bs) @@ -63,7 +62,7 @@ func testUpsertShouldAddNewInBlock(t *testing.T) { require.NoError(t, err) assert.NoError(t, sl.Upsert(ctx, data)) - assert.NoError(t, conn.QueryRow(ctx, "select count(*) from stake_linking").Scan(&rowCount)) + assert.NoError(t, connectionSource.QueryRow(ctx, "select count(*) from stake_linking").Scan(&rowCount)) assert.Equal(t, 1, rowCount) } @@ -71,9 +70,8 @@ func testUpsertShouldUpdateExistingInBlock(t *testing.T) { ctx := tempTransaction(t) bs, sl := setupStakeLinkingTest(t) - conn := connectionSource.Connection var rowCount int - assert.NoError(t, conn.QueryRow(ctx, "select count(*) from stake_linking").Scan(&rowCount)) + assert.NoError(t, connectionSource.QueryRow(ctx, "select count(*) from stake_linking").Scan(&rowCount)) assert.Equal(t, 0, rowCount) block := addTestBlock(t, ctx, bs) @@ -85,7 +83,7 @@ func testUpsertShouldUpdateExistingInBlock(t *testing.T) { assert.NoError(t, sl.Upsert(ctx, data)) } - assert.NoError(t, conn.QueryRow(ctx, "select count(*) from stake_linking").Scan(&rowCount)) + assert.NoError(t, connectionSource.QueryRow(ctx, "select count(*) from stake_linking").Scan(&rowCount)) assert.Equal(t, 2, rowCount) } @@ -93,10 +91,9 @@ func testGetStake(t *testing.T) { ctx := tempTransaction(t) bs, sl := setupStakeLinkingTest(t) - conn := connectionSource.Connection var rowCount int - assert.NoError(t, conn.QueryRow(ctx, "select count(*) from stake_linking").Scan(&rowCount)) + assert.NoError(t, connectionSource.QueryRow(ctx, "select count(*) from stake_linking").Scan(&rowCount)) assert.Equal(t, 0, rowCount) block := addTestBlock(t, ctx, bs) @@ -108,7 +105,7 @@ func testGetStake(t *testing.T) { assert.NoError(t, sl.Upsert(ctx, data)) } - assert.NoError(t, conn.QueryRow(ctx, "select count(*) from stake_linking").Scan(&rowCount)) + assert.NoError(t, connectionSource.QueryRow(ctx, "select count(*) from stake_linking").Scan(&rowCount)) assert.Equal(t, 2, rowCount) partyID := entities.PartyID("cafed00d") @@ -505,7 +502,7 @@ func testStakeLinkingTypeEnum(t *testing.T) { var got entities.StakeLinking - require.NoError(t, pgxscan.Get(ctx, connectionSource.Connection, &got, "SELECT * FROM stake_linking where tx_hash = $1", data.TxHash)) + require.NoError(t, pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM stake_linking where tx_hash = $1", data.TxHash)) assert.Equal(t, data.StakeLinkingType, got.StakeLinkingType) }) } diff --git a/datanode/sqlstore/stop_orders_test.go b/datanode/sqlstore/stop_orders_test.go index d38a79f4032..72153705323 100644 --- a/datanode/sqlstore/stop_orders_test.go +++ b/datanode/sqlstore/stop_orders_test.go @@ -130,7 +130,7 @@ func TestStopOrders_Add(t *testing.T) { require.NoError(t, err) } - rows, err := connectionSource.Connection.Query(ctx, "select * from stop_orders") + rows, err := connectionSource.Query(ctx, "select * from stop_orders") require.NoError(t, err) assert.False(t, rows.Next()) @@ -140,7 +140,7 @@ func TestStopOrders_Add(t *testing.T) { assert.Len(t, orders, len(stopOrders)) var results []entities.StopOrder - err = pgxscan.Select(ctx, connectionSource.Connection, &results, "select * from stop_orders") + err = pgxscan.Select(ctx, connectionSource, &results, "select * from stop_orders") require.NoError(t, err) assert.Len(t, results, len(stopOrders)) assert.ElementsMatch(t, results, orders) diff --git a/datanode/sqlstore/teams_test.go b/datanode/sqlstore/teams_test.go index 16ff8522899..49ba2397259 100644 --- a/datanode/sqlstore/teams_test.go +++ b/datanode/sqlstore/teams_test.go @@ -65,7 +65,7 @@ func TestTeams_AddTeams(t *testing.T) { require.NoError(t, err) var teamFromDB entities.Team - err = pgxscan.Get(ctx, connectionSource.Connection, &teamFromDB, `SELECT * FROM teams WHERE id=$1`, team.ID) + err = pgxscan.Get(ctx, connectionSource, &teamFromDB, `SELECT * FROM teams WHERE id=$1`, team.ID) require.NoError(t, err) require.Equal(t, team, teamFromDB) }) @@ -129,7 +129,7 @@ func TestTeams_UpdateTeam(t *testing.T) { var got entities.Team - err = pgxscan.Get(ctx, connectionSource.Connection, &got, `SELECT * FROM teams WHERE id=$1`, team.ID) + err = pgxscan.Get(ctx, connectionSource, &got, `SELECT * FROM teams WHERE id=$1`, team.ID) require.NoError(t, err) assert.Equal(t, want, got) @@ -188,7 +188,7 @@ func testTeamsShouldAddReferee(t *testing.T) { assert.NoError(t, ts.RefereeJoinedTeam(ctx, teamReferee)) var got entities.TeamMember - require.NoError(t, pgxscan.Get(ctx, connectionSource.Connection, &got, `SELECT * FROM team_members WHERE team_id=$1 AND party_id=$2`, team.ID, referee.ID)) + require.NoError(t, pgxscan.Get(ctx, connectionSource, &got, `SELECT * FROM team_members WHERE team_id=$1 AND party_id=$2`, team.ID, referee.ID)) assert.Equal(t, teamReferee, &got) } @@ -235,7 +235,7 @@ func testTeamsShouldShowJoinedTeamAsCurrentTeam(t *testing.T) { assert.NoError(t, ts.RefereeJoinedTeam(ctx, entities.TeamRefereeFromProto(joinEvent1, block.VegaTime))) var got1 entities.TeamMember - require.NoError(t, pgxscan.Get(ctx, connectionSource.Connection, &got1, `SELECT * FROM current_team_members WHERE party_id=$1`, referee1.ID)) + require.NoError(t, pgxscan.Get(ctx, connectionSource, &got1, `SELECT * FROM current_team_members WHERE party_id=$1`, referee1.ID)) assert.Equal(t, team1.ID, (&got1).TeamID) referee2 := addTestParty(t, ctx, ps, block) @@ -249,7 +249,7 @@ func testTeamsShouldShowJoinedTeamAsCurrentTeam(t *testing.T) { assert.NoError(t, ts.RefereeJoinedTeam(ctx, entities.TeamRefereeFromProto(joinEvent2, block.VegaTime))) var got2 entities.TeamMember - require.NoError(t, pgxscan.Get(ctx, connectionSource.Connection, &got2, `SELECT * FROM current_team_members WHERE party_id=$1`, referee2.ID)) + require.NoError(t, pgxscan.Get(ctx, connectionSource, &got2, `SELECT * FROM current_team_members WHERE party_id=$1`, referee2.ID)) assert.Equal(t, team2.ID, (&got2).TeamID) } @@ -300,7 +300,7 @@ func testTeamsShouldShowLastJoinedTeamAsCurrentTeam(t *testing.T) { assert.NoError(t, ts.RefereeJoinedTeam(ctx, entities.TeamRefereeFromProto(joinEvent1, block.VegaTime))) var got1 entities.TeamMember - require.NoError(t, pgxscan.Get(ctx, connectionSource.Connection, &got1, `SELECT * FROM current_team_members WHERE party_id=$1`, referee.ID)) + require.NoError(t, pgxscan.Get(ctx, connectionSource, &got1, `SELECT * FROM current_team_members WHERE party_id=$1`, referee.ID)) assert.Equal(t, team1.ID, (&got1).TeamID) joinEvent2 := &eventspb.RefereeJoinedTeam{ @@ -312,7 +312,7 @@ func testTeamsShouldShowLastJoinedTeamAsCurrentTeam(t *testing.T) { assert.NoError(t, ts.RefereeJoinedTeam(ctx, entities.TeamRefereeFromProto(joinEvent2, block.VegaTime))) var got2 entities.TeamMember - require.NoError(t, pgxscan.Get(ctx, connectionSource.Connection, &got2, `SELECT * FROM current_team_members WHERE party_id=$1`, referee.ID)) + require.NoError(t, pgxscan.Get(ctx, connectionSource, &got2, `SELECT * FROM current_team_members WHERE party_id=$1`, referee.ID)) assert.Equal(t, team2.ID, (&got2).TeamID) } diff --git a/datanode/sqlstore/time_weighted_notional_position_test.go b/datanode/sqlstore/time_weighted_notional_position_test.go index c0a4fd26460..31c9527d4fc 100644 --- a/datanode/sqlstore/time_weighted_notional_position_test.go +++ b/datanode/sqlstore/time_weighted_notional_position_test.go @@ -44,7 +44,7 @@ func TestTimeWeightedNotionalPosition_Upsert(t *testing.T) { err := tw.Upsert(ctx, want) require.NoError(t, err) var got entities.TimeWeightedNotionalPosition - err = pgxscan.Get(ctx, connectionSource.Connection, &got, + err = pgxscan.Get(ctx, connectionSource, &got, `SELECT * FROM time_weighted_notional_positions WHERE asset_id = $1 AND party_id = $2 and game_id = $3 and epoch_seq = $4`, want.AssetID, want.PartyID, want.GameID, want.EpochSeq) require.NoError(t, err) @@ -67,7 +67,7 @@ func TestTimeWeightedNotionalPosition_Upsert(t *testing.T) { err = tw.Upsert(ctx, want) require.NoError(t, err) var got entities.TimeWeightedNotionalPosition - err = pgxscan.Get(ctx, connectionSource.Connection, &got, + err = pgxscan.Get(ctx, connectionSource, &got, `SELECT * FROM time_weighted_notional_positions WHERE asset_id = $1 AND party_id = $2 and game_id = $3 and epoch_seq = $4`, want.AssetID, want.PartyID, want.GameID, want.EpochSeq) require.NoError(t, err) diff --git a/datanode/sqlstore/volume_discount_programs_test.go b/datanode/sqlstore/volume_discount_programs_test.go index 8fb3d6403ba..22e3058f2ec 100644 --- a/datanode/sqlstore/volume_discount_programs_test.go +++ b/datanode/sqlstore/volume_discount_programs_test.go @@ -93,7 +93,7 @@ func TestVolumeDiscountPrograms_AddVolumeDiscountProgram(t *testing.T) { require.NoError(t, err) var got []entities.VolumeDiscountProgram - require.NoError(t, pgxscan.Select(ctx, connectionSource.Connection, &got, "SELECT * FROM volume_discount_programs")) + require.NoError(t, pgxscan.Select(ctx, connectionSource, &got, "SELECT * FROM volume_discount_programs")) require.Len(t, got, 1) assert.Equal(t, *want, got[0]) @@ -101,7 +101,7 @@ func TestVolumeDiscountPrograms_AddVolumeDiscountProgram(t *testing.T) { err = rs.AddVolumeDiscountProgram(ctx, want2) require.NoError(t, err) - err = pgxscan.Select(ctx, connectionSource.Connection, &got, "SELECT * FROM volume_discount_programs") + err = pgxscan.Select(ctx, connectionSource, &got, "SELECT * FROM volume_discount_programs") require.NoError(t, err) require.Len(t, got, 2) wantAll := []entities.VolumeDiscountProgram{*want, *want2} @@ -181,7 +181,7 @@ func TestVolumeDiscountPrograms_UpdateVolumeDiscountProgram(t *testing.T) { require.NoError(t, err) var got []entities.VolumeDiscountProgram - err = pgxscan.Select(ctx, connectionSource.Connection, &got, "SELECT * FROM volume_discount_programs") + err = pgxscan.Select(ctx, connectionSource, &got, "SELECT * FROM volume_discount_programs") require.NoError(t, err) require.Len(t, got, 1) @@ -192,7 +192,7 @@ func TestVolumeDiscountPrograms_UpdateVolumeDiscountProgram(t *testing.T) { err = rs.UpdateVolumeDiscountProgram(ctx, wantUpdated) require.NoError(t, err) - err = pgxscan.Select(ctx, connectionSource.Connection, &got, "SELECT * FROM volume_discount_programs") + err = pgxscan.Select(ctx, connectionSource, &got, "SELECT * FROM volume_discount_programs") require.NoError(t, err) require.Len(t, got, 2) @@ -203,7 +203,7 @@ func TestVolumeDiscountPrograms_UpdateVolumeDiscountProgram(t *testing.T) { t.Run("The current_referral view should list the updated referral program record", func(t *testing.T) { var got []entities.VolumeDiscountProgram - err := pgxscan.Select(ctx, connectionSource.Connection, &got, "SELECT * FROM current_volume_discount_program") + err := pgxscan.Select(ctx, connectionSource, &got, "SELECT * FROM current_volume_discount_program") require.NoError(t, err) require.Len(t, got, 1) assert.Equal(t, *wantUpdated, got[0]) @@ -237,13 +237,13 @@ func TestVolumeDiscountPrograms_EndVolumeDiscountProgram(t *testing.T) { ended.EndedAt = &endTime var got []entities.VolumeDiscountProgram - err = pgxscan.Select(ctx, connectionSource.Connection, &got, "SELECT * FROM volume_discount_programs order by vega_time") + err = pgxscan.Select(ctx, connectionSource, &got, "SELECT * FROM volume_discount_programs order by vega_time") require.NoError(t, err) require.Len(t, got, 3) wantAll := []entities.VolumeDiscountProgram{*started, *updated, *ended} assert.Equal(t, wantAll, got) - err = pgxscan.Select(ctx, connectionSource.Connection, &got, "SELECT * FROM current_volume_discount_program") + err = pgxscan.Select(ctx, connectionSource, &got, "SELECT * FROM current_volume_discount_program") require.NoError(t, err) require.Len(t, got, 1) assert.Equal(t, *ended, got[0]) diff --git a/datanode/sqlstore/volume_discount_stats_test.go b/datanode/sqlstore/volume_discount_stats_test.go index 960b589b39c..d9d0547bfd2 100644 --- a/datanode/sqlstore/volume_discount_stats_test.go +++ b/datanode/sqlstore/volume_discount_stats_test.go @@ -54,7 +54,7 @@ func TestVolumeDiscountStats_AddVolumeDiscountStats(t *testing.T) { require.NoError(t, vds.Add(ctx, &stats)) var got entities.VolumeDiscountStats - require.NoError(t, pgxscan.Get(ctx, connectionSource.Connection, &got, "SELECT * FROM volume_discount_stats WHERE at_epoch = $1", epoch)) + require.NoError(t, pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM volume_discount_stats WHERE at_epoch = $1", epoch)) assert.Equal(t, stats, got) }) @@ -70,7 +70,7 @@ func TestVolumeDiscountStats_AddVolumeDiscountStats(t *testing.T) { require.NoError(t, vds.Add(ctx, &stats)) var got entities.VolumeDiscountStats - require.NoError(t, pgxscan.Get(ctx, connectionSource.Connection, &got, "SELECT * FROM volume_discount_stats WHERE at_epoch = $1", epoch)) + require.NoError(t, pgxscan.Get(ctx, connectionSource, &got, "SELECT * FROM volume_discount_stats WHERE at_epoch = $1", epoch)) assert.Equal(t, stats, got) err := vds.Add(ctx, &stats) diff --git a/datanode/sqlstore/withdrawals_test.go b/datanode/sqlstore/withdrawals_test.go index 4ce7643a17c..089804618d6 100644 --- a/datanode/sqlstore/withdrawals_test.go +++ b/datanode/sqlstore/withdrawals_test.go @@ -70,7 +70,7 @@ func setupWithdrawalStoreTests(t *testing.T) (*sqlstore.Blocks, *sqlstore.Withdr t.Helper() bs := sqlstore.NewBlocks(connectionSource) ws := sqlstore.NewWithdrawals(connectionSource) - return bs, ws, connectionSource.Connection + return bs, ws, connectionSource } func testAddWithdrawalForNewBlock(t *testing.T) {