Skip to content

Commit

Permalink
test: update unit tests to use the new conneciton source
Browse files Browse the repository at this point in the history
Signed-off-by: Elias Van Ootegem <elias@vega.xyz>
  • Loading branch information
EVODelavega committed Jun 12, 2024
1 parent f14687a commit 2c410c8
Show file tree
Hide file tree
Showing 33 changed files with 132 additions and 122 deletions.
6 changes: 3 additions & 3 deletions datanode/sqlstore/amm_pool_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func TestAMMPool_Upsert(t *testing.T) {
var upserted entities.AMMPool
require.NoError(t, pgxscan.Get(
ctx,
connectionSource.Connection,
connectionSource,
&upserted,
`SELECT * FROM amms WHERE party_id = $1 AND market_id = $2 AND id = $3 AND amm_party_id = $4`,
partyID, marketID, poolID, ammPartyID))
Expand Down Expand Up @@ -121,7 +121,7 @@ func TestAMMPool_Upsert(t *testing.T) {
var upserted entities.AMMPool
require.NoError(t, pgxscan.Get(
ctx,
connectionSource.Connection,
connectionSource,
&upserted,
`SELECT * FROM amms WHERE party_id = $1 AND market_id = $2 AND id = $3 AND amm_party_id = $4`,
partyID, marketID, poolID, ammPartyID))
Expand Down Expand Up @@ -152,7 +152,7 @@ func TestAMMPool_Upsert(t *testing.T) {
var upserted entities.AMMPool
require.NoError(t, pgxscan.Get(
ctx,
connectionSource.Connection,
connectionSource,
&upserted,
`SELECT * FROM amms WHERE party_id = $1 AND market_id = $2 AND id = $3 AND amm_party_id = $4`,
partyID, marketID, poolID, ammPartyID))
Expand Down
9 changes: 8 additions & 1 deletion datanode/sqlstore/assets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,11 +120,18 @@ func TestAssetCache(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, asset2, fetched)

// Commit the transaction and fetch the asset, we should get the asset with the new symbol
// Commit the sub-transaction and fetch the asset, we should not yet get the asset with the new symbol
err = connectionSource.Commit(txCtx)
require.NoError(t, err)
fetched, err = as.GetByID(ctx, string(asset.ID))
require.NoError(t, err)
assert.Equal(t, asset2, fetched)

// now commit the main transaction, then we should get the new symbol
err = connectionSource.Commit(ctx)
require.NoError(t, err)
fetched, err = as.GetByID(context.Background(), string(asset.ID))
require.NoError(t, err)
assert.Equal(t, asset3, fetched)
}

Expand Down
4 changes: 0 additions & 4 deletions datanode/sqlstore/connection_source.go
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,6 @@ func registerNumericType(poolConfig *pgxpool.Config) {
}
}

type delegatingConnection struct {
pool *pgxpool.Pool
}

func CreateConnectionPool(ctx context.Context, conf ConnectionConfig) (*pgxpool.Pool, error) {
poolConfig, err := conf.GetPoolConfig()
if err != nil {
Expand Down
15 changes: 15 additions & 0 deletions datanode/sqlstore/connection_tx.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
// Copyright (C) 2023 Gobalsky Labs Limited
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU Affero General Public License as
// published by the Free Software Foundation, either version 3 of the
// License, or (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU Affero General Public License for more details.
//
// You should have received a copy of the GNU Affero General Public License
// along with this program. If not, see <http://www.gnu.org/licenses/>.

package sqlstore

import (
Expand Down
2 changes: 1 addition & 1 deletion datanode/sqlstore/deposits_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func setupDepositStoreTests(t *testing.T) (*sqlstore.Blocks, *sqlstore.Deposits,
t.Helper()
bs := sqlstore.NewBlocks(connectionSource)
ds := sqlstore.NewDeposits(connectionSource)
return bs, ds, connectionSource.Connection
return bs, ds, connectionSource
}

func testAddDepositForNewBlock(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion datanode/sqlstore/erc20_multisig_added_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestERC20MultiSigEvent(t *testing.T) {
func setupERC20MultiSigEventStoreTests(t *testing.T) (*sqlstore.ERC20MultiSigSignerEvent, sqlstore.Connection) {
t.Helper()
ms := sqlstore.NewERC20MultiSigSignerEvent(connectionSource)
return ms, connectionSource.Connection
return ms, connectionSource
}

func testAddSigner(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion datanode/sqlstore/fees_stats_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func testAddFeesStatsEpochNotExists(t *testing.T) {

// Check that the stats were added
var got entities.FeesStats
err = pgxscan.Get(ctx, connectionSource.Connection, &got,
err = pgxscan.Get(ctx, connectionSource, &got,
`SELECT * FROM fees_stats WHERE market_id = $1 AND asset_id = $2 AND epoch_seq = $3`,
market.ID, asset.ID, want.EpochSeq,
)
Expand Down
10 changes: 5 additions & 5 deletions datanode/sqlstore/funding_period_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func testAddFundingPeriodShouldUpdateIfMarketExistsAndSequenceExists(t *testing.
require.NoError(t, err)

var dbResult entities.FundingPeriod
err = pgxscan.Get(ctx, stores.fp.Connection, &dbResult, `select * from funding_period where market_id = $1 and funding_period_seq = $2`, stores.markets[0].ID, 1)
err = pgxscan.Get(ctx, stores.fp, &dbResult, `select * from funding_period where market_id = $1 and funding_period_seq = $2`, stores.markets[0].ID, 1)
require.NoError(t, err)
assert.Equal(t, period, dbResult)

Expand All @@ -139,7 +139,7 @@ func testAddFundingPeriodShouldUpdateIfMarketExistsAndSequenceExists(t *testing.
err = stores.fp.AddFundingPeriod(ctx, &period)
require.NoError(t, err)

err = pgxscan.Get(ctx, stores.fp.Connection, &dbResult, `select * from funding_period where market_id = $1 and funding_period_seq = $2`, stores.markets[0].ID, 1)
err = pgxscan.Get(ctx, stores.fp, &dbResult, `select * from funding_period where market_id = $1 and funding_period_seq = $2`, stores.markets[0].ID, 1)
require.NoError(t, err)
assert.Equal(t, period, dbResult)
}
Expand Down Expand Up @@ -241,7 +241,7 @@ func testShouldUpdateDataPointInSameBlock(t *testing.T) {
require.NoError(t, err)

var inserted []entities.FundingPeriodDataPoint
err = pgxscan.Select(ctx, connectionSource.Connection, &inserted,
err = pgxscan.Select(ctx, connectionSource, &inserted,
`SELECT * FROM funding_period_data_points where market_id = $1 and funding_period_seq = $2 and data_point_type = $3 and vega_time = $4`,
stores.markets[0].ID, 1, entities.FundingPeriodDataPointSourceExternal, stores.blocks[4].VegaTime)
require.NoError(t, err)
Expand All @@ -262,7 +262,7 @@ func testShouldUpdateDataPointInSameBlock(t *testing.T) {
err = stores.fp.AddDataPoint(ctx, &dp2)
require.NoError(t, err)

err = pgxscan.Select(ctx, connectionSource.Connection, &inserted,
err = pgxscan.Select(ctx, connectionSource, &inserted,
`SELECT * FROM funding_period_data_points where market_id = $1 and funding_period_seq = $2 and data_point_type = $3 and vega_time = $4`,
stores.markets[0].ID, 1, entities.FundingPeriodDataPointSourceExternal, stores.blocks[4].VegaTime)
require.NoError(t, err)
Expand Down Expand Up @@ -799,7 +799,7 @@ func TestFundingPeriodDataPointSource(t *testing.T) {
}
require.NoError(t, stores.fp.AddDataPoint(ctx, &dp))
got := entities.FundingPeriodDataPoint{}
require.NoError(t, pgxscan.Get(ctx, stores.fp.Connection, &got,
require.NoError(t, pgxscan.Get(ctx, stores.fp, &got,
`select * from funding_period_data_points where market_id = $1 and funding_period_seq = $2 and data_point_type = $3`,
dp.MarketID, dp.FundingPeriodSeq, dp.DataPointType),
)
Expand Down
4 changes: 2 additions & 2 deletions datanode/sqlstore/games_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -617,9 +617,9 @@ func setupGamesData(ctx context.Context, t *testing.T, stores gameStores, block
}

// IMPORTANT!!!! We MUST refresh the materialized views or the tests will fail because there will be NO DATA!!!
_, err := connectionSource.Connection.Exec(ctx, "REFRESH MATERIALIZED VIEW game_stats")
_, err := connectionSource.Exec(ctx, "REFRESH MATERIALIZED VIEW game_stats")
require.NoError(t, err)
_, err = connectionSource.Connection.Exec(ctx, "REFRESH MATERIALIZED VIEW game_stats_current")
_, err = connectionSource.Exec(ctx, "REFRESH MATERIALIZED VIEW game_stats_current")
require.NoError(t, err)

return orderResults(results), gameIDs, rewards, teams, individuals
Expand Down
2 changes: 1 addition & 1 deletion datanode/sqlstore/liquidity_provision_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ func setupLPTests(t *testing.T) (*sqlstore.Blocks, *sqlstore.LiquidityProvision,
bs := sqlstore.NewBlocks(connectionSource)
lp := sqlstore.NewLiquidityProvision(connectionSource, logging.NewTestLogger())

return bs, lp, connectionSource.Connection
return bs, lp, connectionSource
}

func testInsertNewInCurrentBlock(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion datanode/sqlstore/margin_level_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func setupMarginLevelTests(t *testing.T, ctx context.Context) (*testBlockSource,
accountStore := sqlstore.NewAccounts(connectionSource)
ml := sqlstore.NewMarginLevels(connectionSource)

return testBlockSource, ml, accountStore, connectionSource.Connection
return testBlockSource, ml, accountStore, connectionSource
}

func testInsertMarginLevels(t *testing.T) {
Expand Down
8 changes: 4 additions & 4 deletions datanode/sqlstore/market_data_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func shouldWorkForAllValuesOfCompositePriceType(t *testing.T) {
addMarketData(t, ctx, "AUCTION_TRIGGER_LIQUIDITY", pt)
var got entities.MarketData

err := connectionSource.Connection.QueryRow(ctx, `select mark_price_type from market_data`).Scan(&got.MarkPriceType)
err := connectionSource.QueryRow(ctx, `select mark_price_type from market_data`).Scan(&got.MarkPriceType)
require.NoError(t, err)

mdProto := got.ToProto()
Expand Down Expand Up @@ -152,7 +152,7 @@ func shouldWorkForAllValuesOfAuctionTrigger(t *testing.T) {
addMarketData(t, ctx, trigger, "COMPOSITE_PRICE_TYPE_LAST_TRADE")
var got entities.MarketData

err := connectionSource.Connection.QueryRow(ctx, `select auction_trigger from market_data`).Scan(&got.AuctionTrigger)
err := connectionSource.QueryRow(ctx, `select auction_trigger from market_data`).Scan(&got.AuctionTrigger)
require.NoError(t, err)

mdProto := got.ToProto()
Expand All @@ -170,7 +170,7 @@ func shouldInsertAValidMarketDataRecord(t *testing.T) {

var rowCount int

err := connectionSource.Connection.QueryRow(ctx, `select count(*) from market_data`).Scan(&rowCount)
err := connectionSource.QueryRow(ctx, `select count(*) from market_data`).Scan(&rowCount)
require.NoError(t, err)
assert.Equal(t, 0, rowCount)

Expand Down Expand Up @@ -202,7 +202,7 @@ func shouldInsertAValidMarketDataRecord(t *testing.T) {
_, err = md.Flush(ctx)
require.NoError(t, err)

err = connectionSource.Connection.QueryRow(ctx, `select count(*) from market_data`).Scan(&rowCount)
err = connectionSource.QueryRow(ctx, `select count(*) from market_data`).Scan(&rowCount)
assert.NoError(t, err)
assert.Equal(t, 1, rowCount)
}
Expand Down
Loading

0 comments on commit 2c410c8

Please sign in to comment.