Skip to content

Commit

Permalink
fix: add support for nested transactions, remove dead code
Browse files Browse the repository at this point in the history
Signed-off-by: Elias Van Ootegem <elias@vega.xyz>
  • Loading branch information
EVODelavega committed Jun 12, 2024
1 parent 7ee64b9 commit dfdd3ba
Show file tree
Hide file tree
Showing 58 changed files with 550 additions and 468 deletions.
16 changes: 8 additions & 8 deletions datanode/sqlstore/accounts.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func NewAccounts(connectionSource *ConnectionSource) *Accounts {
func (as *Accounts) Add(ctx context.Context, a *entities.Account) error {
defer metrics.StartSQLQuery("Accounts", "Add")()

err := as.Connection.QueryRow(ctx,
err := as.QueryRow(ctx,
`INSERT INTO accounts(id, party_id, asset_id, market_id, type, tx_hash, vega_time)
VALUES ($1, $2, $3, $4, $5, $6, $7)
RETURNING id`,
Expand Down Expand Up @@ -87,7 +87,7 @@ func (as *Accounts) GetByID(ctx context.Context, accountID entities.AccountID) (
a := entities.Account{}
defer metrics.StartSQLQuery("Accounts", "GetByID")()

if err := pgxscan.Get(ctx, as.Connection, &a,
if err := pgxscan.Get(ctx, as.ConnectionSource, &a,
`SELECT id, party_id, asset_id, market_id, type, tx_hash, vega_time
FROM accounts WHERE id=$1`,
accountID,
Expand All @@ -102,7 +102,7 @@ func (as *Accounts) GetByID(ctx context.Context, accountID entities.AccountID) (
func (as *Accounts) GetAll(ctx context.Context) ([]entities.Account, error) {
accounts := []entities.Account{}
defer metrics.StartSQLQuery("Accounts", "GetAll")()
err := pgxscan.Select(ctx, as.Connection, &accounts, `
err := pgxscan.Select(ctx, as.ConnectionSource, &accounts, `
SELECT id, party_id, asset_id, market_id, type, tx_hash, vega_time
FROM accounts`)
return accounts, err
Expand All @@ -114,7 +114,7 @@ func (as *Accounts) GetByTxHash(ctx context.Context, txHash entities.TxHash) ([]

err := pgxscan.Select(
ctx,
as.Connection,
as.ConnectionSource,
&accounts,
`SELECT id, party_id, asset_id, market_id, type, tx_hash, vega_time FROM accounts WHERE tx_hash=$1`,
txHash,
Expand Down Expand Up @@ -160,7 +160,7 @@ func (as *Accounts) Obtain(ctx context.Context, a *entities.Account) error {
batch.Queue(insertQuery, accountID, a.PartyID, a.AssetID, a.MarketID, a.Type, a.TxHash, a.VegaTime)
batch.Queue(selectQuery, a.PartyID, a.AssetID, a.MarketID, a.Type)
defer metrics.StartSQLQuery("Accounts", "Obtain")()
results := as.Connection.SendBatch(ctx, &batch)
results := as.SendBatch(ctx, &batch)
defer results.Close()

if _, err := results.Exec(); err != nil {
Expand Down Expand Up @@ -204,7 +204,7 @@ func (as *Accounts) Query(ctx context.Context, filter entities.AccountFilter) ([
accs := []entities.Account{}

defer metrics.StartSQLQuery("Accounts", "Query")()
rows, err := as.Connection.Query(ctx, query, args...)
rows, err := as.ConnectionSource.Query(ctx, query, args...)
if err != nil {
return accs, fmt.Errorf("querying accounts: %w", err)
}
Expand Down Expand Up @@ -234,7 +234,7 @@ func (as *Accounts) QueryBalances(ctx context.Context,
defer metrics.StartSQLQuery("Accounts", "QueryBalances")()

accountBalances := make([]entities.AccountBalance, 0)
rows, err := as.Connection.Query(ctx, query, args...)
rows, err := as.ConnectionSource.Query(ctx, query, args...)
if err != nil {
return accountBalances, entities.PageInfo{}, fmt.Errorf("querying account balances: %w", err)
}
Expand All @@ -254,7 +254,7 @@ func (as *Accounts) GetBalancesByTxHash(ctx context.Context, txHash entities.TxH

err := pgxscan.Select(
ctx,
as.Connection,
as.ConnectionSource,
&balances,
fmt.Sprintf("%s WHERE balances.tx_hash=$1", accountBalancesQuery()),
txHash,
Expand Down
14 changes: 7 additions & 7 deletions datanode/sqlstore/amm_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ func NewAMMPools(connectionSource *ConnectionSource) *AMMPools {

func (p *AMMPools) Upsert(ctx context.Context, pool entities.AMMPool) error {
defer metrics.StartSQLQuery("AMMs", "UpsertAMM")
if _, err := p.Connection.Exec(ctx, `
if _, err := p.ConnectionSource.Exec(ctx, `
insert into amms(party_id, market_id, id, amm_party_id,
commitment, status, status_reason, parameters_base,
parameters_lower_bound, parameters_upper_bound,
Expand Down Expand Up @@ -107,28 +107,28 @@ func listBy[T entities.AMMPoolsFilter](ctx context.Context, connection Connectio

func (p *AMMPools) ListByMarket(ctx context.Context, marketID entities.MarketID, pagination entities.CursorPagination) ([]entities.AMMPool, entities.PageInfo, error) {
defer metrics.StartSQLQuery("AMMs", "ListByMarket")
return listBy(ctx, p.Connection, "market_id", &marketID, pagination)
return listBy(ctx, p.ConnectionSource, "market_id", &marketID, pagination)
}

func (p *AMMPools) ListByParty(ctx context.Context, partyID entities.PartyID, pagination entities.CursorPagination) ([]entities.AMMPool, entities.PageInfo, error) {
defer metrics.StartSQLQuery("AMMs", "ListByParty")

return listBy(ctx, p.Connection, "party_id", &partyID, pagination)
return listBy(ctx, p.ConnectionSource, "party_id", &partyID, pagination)
}

func (p *AMMPools) ListByPool(ctx context.Context, poolID entities.AMMPoolID, pagination entities.CursorPagination) ([]entities.AMMPool, entities.PageInfo, error) {
defer metrics.StartSQLQuery("AMMs", "ListByPool")
return listBy(ctx, p.Connection, "id", &poolID, pagination)
return listBy(ctx, p.ConnectionSource, "id", &poolID, pagination)
}

func (p *AMMPools) ListBySubAccount(ctx context.Context, ammPartyID entities.PartyID, pagination entities.CursorPagination) ([]entities.AMMPool, entities.PageInfo, error) {
defer metrics.StartSQLQuery("AMMs", "ListByAMMParty")
return listBy(ctx, p.Connection, "amm_party_id", &ammPartyID, pagination)
return listBy(ctx, p.ConnectionSource, "amm_party_id", &ammPartyID, pagination)
}

func (p *AMMPools) ListByStatus(ctx context.Context, status entities.AMMStatus, pagination entities.CursorPagination) ([]entities.AMMPool, entities.PageInfo, error) {
defer metrics.StartSQLQuery("AMMs", "ListByStatus")
return listBy(ctx, p.Connection, "status", &status, pagination)
return listBy(ctx, p.ConnectionSource, "status", &status, pagination)
}

func (p *AMMPools) ListAll(ctx context.Context, pagination entities.CursorPagination) ([]entities.AMMPool, entities.PageInfo, error) {
Expand All @@ -144,7 +144,7 @@ func (p *AMMPools) ListAll(ctx context.Context, pagination entities.CursorPagina
return nil, pageInfo, err
}

if err := pgxscan.Select(ctx, p.Connection, &pools, query, args...); err != nil {
if err := pgxscan.Select(ctx, p.ConnectionSource, &pools, query, args...); err != nil {
return nil, pageInfo, fmt.Errorf("could not list AMMs: %w", err)
}

Expand Down
10 changes: 5 additions & 5 deletions datanode/sqlstore/assets.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ func NewAssets(connectionSource *ConnectionSource) *Assets {

func (as *Assets) Add(ctx context.Context, a entities.Asset) error {
defer metrics.StartSQLQuery("Assets", "Add")()
_, err := as.Connection.Exec(ctx,
_, err := as.Exec(ctx,
`INSERT INTO assets(id, name, symbol, decimals, quantum, source, erc20_contract, lifetime_limit, withdraw_threshold, tx_hash, vega_time, status, chain_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
ON CONFLICT (id, vega_time) DO UPDATE SET
Expand Down Expand Up @@ -102,7 +102,7 @@ func (as *Assets) GetByID(ctx context.Context, id string) (entities.Asset, error
a := entities.Asset{}

defer metrics.StartSQLQuery("Assets", "GetByID")()
err := pgxscan.Get(ctx, as.Connection, &a,
err := pgxscan.Get(ctx, as.ConnectionSource, &a,
getAssetQuery(ctx)+` WHERE id=$1`,
entities.AssetID(id))

Expand All @@ -116,7 +116,7 @@ func (as *Assets) GetByTxHash(ctx context.Context, txHash entities.TxHash) ([]en
defer metrics.StartSQLQuery("Assets", "GetByTxHash")()

var assets []entities.Asset
err := pgxscan.Select(ctx, as.Connection, &assets, `SELECT * FROM assets WHERE tx_hash=$1`, txHash)
err := pgxscan.Select(ctx, as.ConnectionSource, &assets, `SELECT * FROM assets WHERE tx_hash=$1`, txHash)
if err != nil {
return nil, as.wrapE(err)
}
Expand All @@ -127,7 +127,7 @@ func (as *Assets) GetByTxHash(ctx context.Context, txHash entities.TxHash) ([]en
func (as *Assets) GetAll(ctx context.Context) ([]entities.Asset, error) {
assets := []entities.Asset{}
defer metrics.StartSQLQuery("Assets", "GetAll")()
err := pgxscan.Select(ctx, as.Connection, &assets, getAssetQuery(ctx))
err := pgxscan.Select(ctx, as.ConnectionSource, &assets, getAssetQuery(ctx))
return assets, err
}

Expand All @@ -146,7 +146,7 @@ func (as *Assets) GetAllWithCursorPagination(ctx context.Context, pagination ent
}
defer metrics.StartSQLQuery("Assets", "GetAllWithCursorPagination")()

if err = pgxscan.Select(ctx, as.Connection, &assets, query, args...); err != nil {
if err = pgxscan.Select(ctx, as.ConnectionSource, &assets, query, args...); err != nil {
return nil, pageInfo, fmt.Errorf("could not get assets: %w", err)
}

Expand Down
4 changes: 2 additions & 2 deletions datanode/sqlstore/balances.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func NewBalances(connectionSource *ConnectionSource) *Balances {

func (bs *Balances) Flush(ctx context.Context) ([]entities.AccountBalance, error) {
defer metrics.StartSQLQuery("Balances", "Flush")()
return bs.batcher.Flush(ctx, bs.Connection)
return bs.batcher.Flush(ctx, bs.ConnectionSource)
}

// Add inserts a row to the balance table. If there's already a balance for this
Expand Down Expand Up @@ -99,7 +99,7 @@ func (bs *Balances) Query(ctx context.Context, filter entities.AccountFilter, da
}

defer metrics.StartSQLQuery("Balances", "Query")()
rows, err := bs.Connection.Query(ctx, query, args...)
rows, err := bs.ConnectionSource.Query(ctx, query, args...)
if err != nil {
return nil, pageInfo, fmt.Errorf("querying balances: %w", err)
}
Expand Down
10 changes: 5 additions & 5 deletions datanode/sqlstore/blocks.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func NewBlocks(connectionSource *ConnectionSource) *Blocks {
func (bs *Blocks) Add(ctx context.Context, b entities.Block) error {
defer metrics.StartSQLQuery("Blocks", "Add")()

_, err := bs.Connection.Exec(ctx,
_, err := bs.Exec(ctx,
`insert into blocks(vega_time, height, hash) values ($1, $2, $3)`,
b.VegaTime, b.Height, b.Hash)
if err != nil {
Expand All @@ -64,15 +64,15 @@ func (bs *Blocks) Add(ctx context.Context, b entities.Block) error {
func (bs *Blocks) GetAll(ctx context.Context) ([]entities.Block, error) {
defer metrics.StartSQLQuery("Blocks", "GetAll")()
blocks := []entities.Block{}
err := pgxscan.Select(ctx, bs.Connection, &blocks,
err := pgxscan.Select(ctx, bs.ConnectionSource, &blocks,
`SELECT vega_time, height, hash
FROM blocks
ORDER BY vega_time desc`)
return blocks, err
}

func (bs *Blocks) GetAtHeight(ctx context.Context, height int64) (entities.Block, error) {
connection := bs.Connection
connection := bs.ConnectionSource
defer metrics.StartSQLQuery("Blocks", "GetAtHeight")()

// Check if it's in our cache first
Expand Down Expand Up @@ -102,7 +102,7 @@ func (bs *Blocks) GetLastBlock(ctx context.Context) (entities.Block, error) {
}
defer metrics.StartSQLQuery("Blocks", "GetLastBlock")()

lastBlock, err := bs.getLastBlockUsingConnection(ctx, bs.Connection)
lastBlock, err := bs.getLastBlockUsingConnection(ctx, bs.ConnectionSource)
// FIXME(woot?): why do we set that before checking for error, that would clearly fuckup the cache or something innit?
bs.lastBlock = lastBlock
if err != nil {
Expand All @@ -121,7 +121,7 @@ func (bs *Blocks) setLastBlock(b entities.Block) {
func (bs *Blocks) GetOldestHistoryBlock(ctx context.Context) (entities.Block, error) {
defer metrics.StartSQLQuery("Blocks", "GetOldestHistoryBlock")()

return bs.getOldestHistoryBlockUsingConnection(ctx, bs.Connection)
return bs.getOldestHistoryBlockUsingConnection(ctx, bs.ConnectionSource)
}

func (bs *Blocks) getOldestHistoryBlockUsingConnection(ctx context.Context, connection Connection) (entities.Block, error) {
Expand Down
12 changes: 6 additions & 6 deletions datanode/sqlstore/candles.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ func (cs *Candles) getCandlesSubquery(ctx context.Context, descriptor candleDesc
if from == nil || to == nil {
datesQuery := fmt.Sprintf("select min(period_start) as start_date, max(period_start) as end_date from %s where market_id = $1", descriptor.view)
marketID := entities.MarketID(descriptor.market)
err := pgxscan.Get(ctx, cs.Connection, &candlesDateRange, datesQuery, marketID)
err := pgxscan.Get(ctx, cs.ConnectionSource, &candlesDateRange, datesQuery, marketID)
if err != nil {
return "", args, fmt.Errorf("querying candles date range: %w", err)
}
Expand Down Expand Up @@ -195,7 +195,7 @@ func (cs *Candles) GetCandleDataForTimeSpan(ctx context.Context, candleID string
query = fmt.Sprintf("with gap_filled_candles as (%s) %s", subQuery, query)

defer metrics.StartSQLQuery("Candles", "GetCandleDataForTimeSpan")()
err = pgxscan.Select(ctx, cs.Connection, &candles, query, args...)
err = pgxscan.Select(ctx, cs.ConnectionSource, &candles, query, args...)
if err != nil {
return nil, pageInfo, fmt.Errorf("querying candles: %w", err)
}
Expand Down Expand Up @@ -244,7 +244,7 @@ func (cs *Candles) getIntervalToView(ctx context.Context) (map[string]string, er
query := fmt.Sprintf("SELECT table_name AS view_name FROM INFORMATION_SCHEMA.views WHERE table_name LIKE '%s%%'",
candlesViewNamePrePend)
defer metrics.StartSQLQuery("Candles", "GetIntervalToView")()
rows, err := cs.Connection.Query(ctx, query)
rows, err := cs.Query(ctx, query)
if err != nil {
return nil, fmt.Errorf("fetching existing views for interval: %w", err)
}
Expand Down Expand Up @@ -324,13 +324,13 @@ func (cs *Candles) normaliseInterval(ctx context.Context, interval string) (stri
var normalizedInterval string

defer metrics.StartSQLQuery("Candles", "normaliseInterval")()
_, err := cs.Connection.Exec(ctx, "SET intervalstyle = 'postgres_verbose' ")
_, err := cs.Exec(ctx, "SET intervalstyle = 'postgres_verbose' ")
if err != nil {
return "", fmt.Errorf("normalising interval, failed to set interval style:%w", err)
}

query := fmt.Sprintf("select cast( INTERVAL '%s' as text)", interval)
row := cs.Connection.QueryRow(ctx, query)
row := cs.QueryRow(ctx, query)

err = row.Scan(&normalizedInterval)
if err != nil {
Expand All @@ -347,7 +347,7 @@ func (cs *Candles) getIntervalSeconds(ctx context.Context, interval string) (int

defer metrics.StartSQLQuery("Candles", "getIntervalSeconds")()
query := fmt.Sprintf("SELECT EXTRACT(epoch FROM INTERVAL '%s')", interval)
row := cs.Connection.QueryRow(ctx, query)
row := cs.QueryRow(ctx, query)

err := row.Scan(&seconds)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions datanode/sqlstore/chain.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,13 +40,13 @@ func (c *Chain) Get(ctx context.Context) (entities.Chain, error) {
chain := entities.Chain{}

query := `SELECT id from chain`
return chain, c.wrapE(pgxscan.Get(ctx, c.Connection, &chain, query))
return chain, c.wrapE(pgxscan.Get(ctx, c.ConnectionSource, &chain, query))
}

func (c *Chain) Set(ctx context.Context, chain entities.Chain) error {
defer metrics.StartSQLQuery("Chain", "Set")()
query := `INSERT INTO chain(id) VALUES($1)`
_, err := c.Connection.Exec(ctx, query, chain.ID)
_, err := c.Exec(ctx, query, chain.ID)
if e, ok := err.(*pgconn.PgError); ok {
// 23505 is postgres error code for a unique constraint violation
if e.Code == "23505" {
Expand Down
4 changes: 2 additions & 2 deletions datanode/sqlstore/checkpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ func NewCheckpoints(connectionSource *ConnectionSource) *Checkpoints {

func (c *Checkpoints) Add(ctx context.Context, r entities.Checkpoint) error {
defer metrics.StartSQLQuery("Checkpoints", "Add")()
_, err := c.Connection.Exec(ctx,
_, err := c.Exec(ctx,
`INSERT INTO checkpoints(
hash,
block_hash,
Expand All @@ -70,7 +70,7 @@ func (c *Checkpoints) GetAll(ctx context.Context, pagination entities.CursorPagi
return nps, pageInfo, err
}

if err = pgxscan.Select(ctx, c.Connection, &nps, query, args...); err != nil {
if err = pgxscan.Select(ctx, c.ConnectionSource, &nps, query, args...); err != nil {
return nil, pageInfo, fmt.Errorf("could not get checkpoint data: %w", err)
}

Expand Down
Loading

0 comments on commit dfdd3ba

Please sign in to comment.