Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(postgres): ensure consistency of prepare order, fix #69 #70

Merged
merged 2 commits into from
Dec 30, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 58 additions & 37 deletions packages/postgres/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,12 @@
commit_action: null
}

interface QueryTask {
sql: string
resolve: (value: any) => void
reject: (reason: unknown) => void
}

function escapeId(value: string) {
return '"' + value.replace(/"/g, '""') + '"'
}
Expand Down Expand Up @@ -436,6 +442,7 @@

private session?: postgres.TransactionSql
private _counter = 0
private _queryTasks: QueryTask[] = []

constructor(database: Database, config: PostgresDriver.Config) {
super(database)
Expand Down Expand Up @@ -474,18 +481,39 @@
})
}

queue<T extends any[] = any[]>(sql: string, values?: any): Promise<T> {
if (this.session) {
return this.query(sql)
}

return new Promise<any>((resolve, reject) => {
this._queryTasks.push({ sql, resolve, reject })
process.nextTick(() => this._flushTasks())
})
}

private async _flushTasks() {
const tasks = this._queryTasks
if (!tasks.length) return
this._queryTasks = []

try {
let results = await this.query(tasks.map(task => task.sql).join(';\n')) as any
if (tasks.length === 1) results = [results]
tasks.forEach((task, index) => {
task.resolve(results[index])
})
} catch (error) {
tasks.forEach(task => task.reject(error))
}

Check warning on line 508 in packages/postgres/src/index.ts

View check run for this annotation

Codecov / codecov/patch

packages/postgres/src/index.ts#L507-L508

Added lines #L507 - L508 were not covered by tests
}

async prepare(name: string) {
const [columns, constraints] = await Promise.all([
this.query<ColumnInfo[]>(`
SELECT *
FROM information_schema.columns
WHERE table_schema = 'public'
AND table_name = ${this.sql.escape(name)}`),
this.query<ConstraintInfo[]>(`
SELECT *
FROM information_schema.table_constraints
WHERE table_schema = 'public'
AND table_name = ${this.sql.escape(name)}`),
this.queue<ColumnInfo[]>(`SELECT * FROM information_schema.columns WHERE table_schema = 'public' AND table_name = ${this.sql.escape(name)}`),
this.queue<ConstraintInfo[]>(
`SELECT * FROM information_schema.table_constraints WHERE table_schema = 'public' AND table_name = ${this.sql.escape(name)}`,
),
])

const table = this.model(name)
Expand Down Expand Up @@ -587,27 +615,19 @@
await this.query(`DROP TABLE IF EXISTS ${escapeId(table)} CASCADE`)
return
}
const tables: TableInfo[] = await this.query(`
SELECT *
FROM information_schema.tables
WHERE table_schema = 'public'`)
const tables: TableInfo[] = await this.queue(`SELECT * FROM information_schema.tables WHERE table_schema = 'public'`)
if (!tables.length) return
await this.query(`DROP TABLE IF EXISTS ${tables.map(t => escapeId(t.table_name)).join(',')} CASCADE`)
}

async stats(): Promise<Partial<Driver.Stats>> {
const names = Object.keys(this.database.tables)
const tables = (await this.query<TableInfo[]>(`
SELECT *
FROM information_schema.tables
WHERE table_schema = 'public'`))
const tables = (await this.queue<TableInfo[]>(`SELECT * FROM information_schema.tables WHERE table_schema = 'public'`))
.map(t => t.table_name).filter(name => names.includes(name))
const tableStats = await this.query(
tables.map(name => {
return `SELECT '${name}' AS name,
pg_total_relation_size('${escapeId(name)}') AS size,
COUNT(*) AS count FROM ${escapeId(name)}`
}).join(' UNION '),
const tableStats = await this.queue(
tables.map(
(name) => `SELECT '${name}' AS name, pg_total_relation_size('${escapeId(name)}') AS size, COUNT(*) AS count FROM ${escapeId(name)}`,
).join(' UNION '),
).then(s => s.map(t => [t.name, { size: +t.size, count: +t.count }]))

return {
Expand All @@ -620,7 +640,7 @@
const builder = new PostgresBuilder(sel.tables)
const query = builder.get(sel)
if (!query) return []
return this.query(query).then(data => {
return this.queue(query).then(data => {
return data.map(row => builder.load(sel.model, row))
})
}
Expand All @@ -630,7 +650,7 @@
const inner = builder.get(sel.table as Selection, true, true)
const output = builder.parseEval(expr, false)
const ref = isBracketed(inner) ? sel.ref : ''
const [data] = await this.query(`SELECT ${output} AS value FROM ${inner} ${ref}`)
const [data] = await this.queue(`SELECT ${output} AS value FROM ${inner} ${ref}`)
return builder.load(data?.value)
}

Expand Down Expand Up @@ -665,10 +685,11 @@
const builder = new PostgresBuilder(sel.tables)
const formatted = builder.dump(model, data)
const keys = Object.keys(formatted)
const [row] = await this.query(`
INSERT INTO ${builder.escapeId(table)} (${keys.map(builder.escapeId).join(', ')})
VALUES (${keys.map(key => builder.escape(formatted[key])).join(', ')})
RETURNING *`)
const [row] = await this.query([
`INSERT INTO ${builder.escapeId(table)} (${keys.map(builder.escapeId).join(', ')})`,
`VALUES (${keys.map(key => builder.escape(formatted[key])).join(', ')})`,
`RETURNING *`,
].join(' '))
return builder.load(model, row)
}

Expand Down Expand Up @@ -731,13 +752,13 @@
return `${escaped} = ${value}`
}).join(', ')

const result = await this.query(`
INSERT INTO ${builder.escapeId(table)} (${initFields.map(builder.escapeId).join(', ')})
VALUES (${insertion.map(item => formatValues(table, item, initFields)).join('), (')})
ON CONFLICT (${keys.map(builder.escapeId).join(', ')})
DO UPDATE SET ${update}, _pg_mtime = ${mtime}
RETURNING _pg_mtime as rtime
`)
const result = await this.query([
`INSERT INTO ${builder.escapeId(table)} (${initFields.map(builder.escapeId).join(', ')})`,
`VALUES (${insertion.map(item => formatValues(table, item, initFields)).join('), (')})`,
`ON CONFLICT (${keys.map(builder.escapeId).join(', ')})`,
`DO UPDATE SET ${update}, _pg_mtime = ${mtime}`,
`RETURNING _pg_mtime as rtime`,
].join(' '))
return { inserted: result.filter(({ rtime }) => +rtime !== mtime).length, matched: result.filter(({ rtime }) => +rtime === mtime).length }
}

Expand Down