From 3a356d20a6103b710dcc35b8e2ec8785ed981951 Mon Sep 17 00:00:00 2001 From: Shigma Date: Sat, 10 Feb 2024 14:15:14 +0800 Subject: [PATCH] refa: separate builder class --- packages/core/package.json | 2 +- packages/core/src/database.ts | 44 +++-- packages/mysql/src/builder.ts | 170 ++++++++++++++++++ packages/mysql/src/index.ts | 174 +----------------- packages/postgres/src/builder.ts | 293 ++++++++++++++++++++++++++++++ packages/postgres/src/index.ts | 297 +------------------------------ packages/sqlite/src/builder.ts | 106 +++++++++++ packages/sqlite/src/index.ts | 108 +---------- 8 files changed, 608 insertions(+), 586 deletions(-) create mode 100644 packages/mysql/src/builder.ts create mode 100644 packages/postgres/src/builder.ts create mode 100644 packages/sqlite/src/builder.ts diff --git a/packages/core/package.json b/packages/core/package.json index 5b2f4838..05a6793c 100644 --- a/packages/core/package.json +++ b/packages/core/package.json @@ -42,7 +42,7 @@ "postgres" ], "dependencies": { - "cordis": "^3.9.1", + "cordis": "^3.9.2", "cosmokit": "^1.5.2" } } diff --git a/packages/core/src/database.ts b/packages/core/src/database.ts index 5c9833ac..ec5d5791 100644 --- a/packages/core/src/database.ts +++ b/packages/core/src/database.ts @@ -14,26 +14,36 @@ type TableType> = : T extends Selection ? U : never -type TableMap1[]> = Intersect< - | M extends readonly (infer K extends Keys)[] - ? { [P in K]: TableType } - : never -> +export namespace Join1 { + export type Input = readonly Keys[] + + export type Output> = Intersect< + | U extends readonly (infer K extends Keys)[] + ? { [P in K]: TableType } + : never + > -type TableMap2>> = { - [K in keyof U]: TableType + type Parameters> = + | U extends readonly [infer K extends Keys, ...infer R] + ? [Row, ...Parameters>>] + : [] + + export type Predicate> = (...args: Parameters) => Eval.Expr } -type JoinParameters[]> = - | U extends readonly [infer K extends Keys, ...infer R] - ? [Row, ...JoinParameters[]>>] - : [] +export namespace Join2 { + export type Input = Dict> -type JoinCallback1[]> = (...args: JoinParameters) => Eval.Expr + export type Output> = { + [K in keyof U]: TableType + } -type JoinCallback2>> = (args: { - [K in keyof U]: Row> -}) => Eval.Expr + type Parameters> = { + [K in keyof U]: Row> + } + + export type Predicate> = (args: Parameters) => Eval.Expr +} const kTransaction = Symbol('transaction') @@ -105,8 +115,8 @@ export class Database extends Service { return new Selection(this.getDriver(table), table, query) } - join[]>(tables: U, callback?: JoinCallback1, optional?: boolean[]): Selection> - join>>(tables: U, callback?: JoinCallback2, optional?: Dict>): Selection> + join>(tables: U, callback?: Join1.Predicate, optional?: boolean[]): Selection> + join>(tables: U, callback?: Join2.Predicate, optional?: Dict>): Selection> join(tables: any, query?: any, optional?: any) { if (Array.isArray(tables)) { const sel = new Selection(this.getDriver(tables[0]), Object.fromEntries(tables.map((name) => [name, this.select(name)]))) diff --git a/packages/mysql/src/builder.ts b/packages/mysql/src/builder.ts new file mode 100644 index 00000000..3ff13103 --- /dev/null +++ b/packages/mysql/src/builder.ts @@ -0,0 +1,170 @@ +import { Builder, escapeId, isBracketed } from '@minatojs/sql-utils' +import { Dict, Time } from 'cosmokit' +import { Field, isEvalExpr, Model, randomId, Selection } from 'minato' + +export const DEFAULT_DATE = new Date('1970-01-01') + +export interface Compat { + maria?: boolean + maria105?: boolean + mysql57?: boolean +} + +export class MySQLBuilder extends Builder { + // eslint-disable-next-line no-control-regex + protected escapeRegExp = /[\0\b\t\n\r\x1a'"\\]/g + protected escapeMap = { + '\0': '\\0', + '\b': '\\b', + '\t': '\\t', + '\n': '\\n', + '\r': '\\r', + '\x1a': '\\Z', + '\"': '\\\"', + '\'': '\\\'', + '\\': '\\\\', + } + + prequeries: string[] = [] + + constructor(tables?: Dict, private compat: Compat = {}) { + super(tables) + + this.evalOperators.$sum = (expr) => this.createAggr(expr, value => `ifnull(sum(${value}), 0)`, undefined, value => `ifnull(minato_cfunc_sum(${value}), 0)`) + this.evalOperators.$avg = (expr) => this.createAggr(expr, value => `avg(${value})`, undefined, value => `minato_cfunc_avg(${value})`) + this.evalOperators.$min = (expr) => this.createAggr(expr, value => `min(${value})`, undefined, value => `minato_cfunc_min(${value})`) + this.evalOperators.$max = (expr) => this.createAggr(expr, value => `max(${value})`, undefined, value => `minato_cfunc_max(${value})`) + + this.define({ + types: ['list'], + dump: value => value.join(','), + load: value => value ? value.split(',') : [], + }) + + this.define({ + types: ['json'], + dump: value => JSON.stringify(value), + load: value => typeof value === 'string' ? JSON.parse(value) : value, + }) + + this.define({ + types: ['time'], + dump: value => value, + load: (value) => { + if (!value || typeof value === 'object') return value + const time = new Date(DEFAULT_DATE) + const [h, m, s] = value.split(':') + time.setHours(parseInt(h)) + time.setMinutes(parseInt(m)) + time.setSeconds(parseInt(s)) + return time + }, + }) + } + + escape(value: any, field?: Field) { + if (value instanceof Date) { + value = Time.template('yyyy-MM-dd hh:mm:ss', value) + } else if (value instanceof RegExp) { + value = value.source + } else if (!field && !!value && typeof value === 'object') { + return `json_extract(${this.quote(JSON.stringify(value))}, '$')` + } + return super.escape(value, field) + } + + protected jsonQuote(value: string, pure: boolean = false) { + if (pure) return this.compat.maria ? `json_extract(json_object('v', ${value}), '$.v')` : `cast(${value} as json)` + const res = this.state.sqlType === 'raw' ? (this.compat.maria ? `json_extract(json_object('v', ${value}), '$.v')` : `cast(${value} as json)`) : value + this.state.sqlType = 'json' + return res + } + + protected createAggr(expr: any, aggr: (value: string) => string, nonaggr?: (value: string) => string, compat?: (value: string) => string) { + if (!this.state.group && compat && (this.compat.mysql57 || this.compat.maria)) { + return compat(this.parseEval(expr, false)) + } else { + return super.createAggr(expr, aggr, nonaggr) + } + } + + protected groupArray(value: string) { + if (!this.compat.maria) return super.groupArray(value) + const res = this.state.sqlType === 'json' ? `concat('[', group_concat(${value}), ']')` + : `concat('[', group_concat(json_extract(json_object('v', ${value}), '$.v')), ']')` + this.state.sqlType = 'json' + return `ifnull(${res}, json_array())` + } + + protected parseSelection(sel: Selection) { + if (!this.compat.maria && !this.compat.mysql57) return super.parseSelection(sel) + const { args: [expr], ref, table, tables } = sel + const restore = this.saveState({ wrappedSubquery: true, tables }) + const inner = this.get(table as Selection, true, true) as string + const output = this.parseEval(expr, false) + const refFields = this.state.refFields + restore() + let query: string + if (!(sel.args[0] as any).$) { + query = `(SELECT ${output} AS value FROM ${inner} ${isBracketed(inner) ? ref : ''})` + } else { + query = `(ifnull((SELECT ${this.groupArray(output)} AS value FROM ${inner} ${isBracketed(inner) ? ref : ''}), json_array()))` + } + if (Object.keys(refFields ?? {}).length) { + const funcname = `minato_tfunc_${randomId()}` + const decls = Object.values(refFields ?? {}).map(x => `${x} JSON`).join(',') + const args = Object.keys(refFields ?? {}).map(x => this.state.refFields?.[x] ?? x).map(x => this.jsonQuote(x, true)).join(',') + query = this.state.sqlType === 'json' ? `ifnull(${query}, json_array())` : this.jsonQuote(query) + this.prequeries.push(`DROP FUNCTION IF EXISTS ${funcname}`) + this.prequeries.push(`CREATE FUNCTION ${funcname} (${decls}) RETURNS JSON DETERMINISTIC RETURN ${query}`) + this.state.sqlType = 'json' + return `${funcname}(${args})` + } else return query + } + + toUpdateExpr(item: any, key: string, field?: Field, upsert?: boolean) { + const escaped = escapeId(key) + + // update directly + if (key in item) { + if (!isEvalExpr(item[key]) && upsert) { + return `VALUES(${escaped})` + } else if (isEvalExpr(item[key])) { + return this.parseEval(item[key]) + } else { + return this.escape(item[key], field) + } + } + + // prepare nested layout + const jsonInit = {} + for (const prop in item) { + if (!prop.startsWith(key + '.')) continue + const rest = prop.slice(key.length + 1).split('.') + if (rest.length === 1) continue + rest.reduce((obj, k) => obj[k] ??= {}, jsonInit) + } + + // update with json_set + const valueInit = `ifnull(${escaped}, '{}')` + let value = valueInit + + // json_set cannot create deeply nested property when non-exist + // therefore we merge a layout to it + if (Object.keys(jsonInit).length !== 0) { + value = `json_merge(${value}, ${this.quote(JSON.stringify(jsonInit))})` + } + + for (const prop in item) { + if (!prop.startsWith(key + '.')) continue + const rest = prop.slice(key.length + 1).split('.') + value = `json_set(${value}, '$${rest.map(key => `."${key}"`).join('')}', ${this.parseEval(item[prop])})` + } + + if (value === valueInit) { + return escaped + } else { + return value + } + } +} diff --git a/packages/mysql/src/index.ts b/packages/mysql/src/index.ts index 53e1799a..988c2ae5 100644 --- a/packages/mysql/src/index.ts +++ b/packages/mysql/src/index.ts @@ -1,8 +1,9 @@ import { createPool, format } from '@vlasky/mysql' import type { OkPacket, Pool, PoolConfig, PoolConnection } from 'mysql' -import { Dict, difference, makeArray, pick, Time } from 'cosmokit' -import { Driver, Eval, executeUpdate, Field, isEvalExpr, Model, randomId, RuntimeError, Selection, z } from 'minato' -import { Builder, escapeId, isBracketed } from '@minatojs/sql-utils' +import { Dict, difference, makeArray, pick } from 'cosmokit' +import { Driver, Eval, executeUpdate, Field, RuntimeError, Selection, z } from 'minato' +import { escapeId, isBracketed } from '@minatojs/sql-utils' +import { Compat, DEFAULT_DATE, MySQLBuilder } from './builder' declare module 'mysql' { interface UntypedFieldInfo { @@ -10,8 +11,6 @@ declare module 'mysql' { } } -const DEFAULT_DATE = new Date('1970-01-01') - function getIntegerType(length = 4) { if (length <= 1) return 'tinyint' if (length <= 2) return 'smallint' @@ -69,12 +68,6 @@ function createIndex(keys: string | string[]) { return makeArray(keys).map(escapeId).join(', ') } -interface Compat { - maria?: boolean - maria105?: boolean - mysql57?: boolean -} - interface ColumnInfo { COLUMN_NAME: string IS_NULLABLE: 'YES' | 'NO' @@ -96,165 +89,6 @@ interface QueryTask { reject: (reason: unknown) => void } -class MySQLBuilder extends Builder { - // eslint-disable-next-line no-control-regex - protected escapeRegExp = /[\0\b\t\n\r\x1a'"\\]/g - protected escapeMap = { - '\0': '\\0', - '\b': '\\b', - '\t': '\\t', - '\n': '\\n', - '\r': '\\r', - '\x1a': '\\Z', - '\"': '\\\"', - '\'': '\\\'', - '\\': '\\\\', - } - - prequeries: string[] = [] - - constructor(tables?: Dict, private compat: Compat = {}) { - super(tables) - - this.evalOperators.$sum = (expr) => this.createAggr(expr, value => `ifnull(sum(${value}), 0)`, undefined, value => `ifnull(minato_cfunc_sum(${value}), 0)`) - this.evalOperators.$avg = (expr) => this.createAggr(expr, value => `avg(${value})`, undefined, value => `minato_cfunc_avg(${value})`) - this.evalOperators.$min = (expr) => this.createAggr(expr, value => `min(${value})`, undefined, value => `minato_cfunc_min(${value})`) - this.evalOperators.$max = (expr) => this.createAggr(expr, value => `max(${value})`, undefined, value => `minato_cfunc_max(${value})`) - - this.define({ - types: ['list'], - dump: value => value.join(','), - load: value => value ? value.split(',') : [], - }) - - this.define({ - types: ['json'], - dump: value => JSON.stringify(value), - load: value => typeof value === 'string' ? JSON.parse(value) : value, - }) - - this.define({ - types: ['time'], - dump: value => value, - load: (value) => { - if (!value || typeof value === 'object') return value - const time = new Date(DEFAULT_DATE) - const [h, m, s] = value.split(':') - time.setHours(parseInt(h)) - time.setMinutes(parseInt(m)) - time.setSeconds(parseInt(s)) - return time - }, - }) - } - - escape(value: any, field?: Field) { - if (value instanceof Date) { - value = Time.template('yyyy-MM-dd hh:mm:ss', value) - } else if (value instanceof RegExp) { - value = value.source - } else if (!field && !!value && typeof value === 'object') { - return `json_extract(${this.quote(JSON.stringify(value))}, '$')` - } - return super.escape(value, field) - } - - protected jsonQuote(value: string, pure: boolean = false) { - if (pure) return this.compat.maria ? `json_extract(json_object('v', ${value}), '$.v')` : `cast(${value} as json)` - const res = this.state.sqlType === 'raw' ? (this.compat.maria ? `json_extract(json_object('v', ${value}), '$.v')` : `cast(${value} as json)`) : value - this.state.sqlType = 'json' - return res - } - - protected createAggr(expr: any, aggr: (value: string) => string, nonaggr?: (value: string) => string, compat?: (value: string) => string) { - if (!this.state.group && compat && (this.compat.mysql57 || this.compat.maria)) { - return compat(this.parseEval(expr, false)) - } else { - return super.createAggr(expr, aggr, nonaggr) - } - } - - protected groupArray(value: string) { - if (!this.compat.maria) return super.groupArray(value) - const res = this.state.sqlType === 'json' ? `concat('[', group_concat(${value}), ']')` - : `concat('[', group_concat(json_extract(json_object('v', ${value}), '$.v')), ']')` - this.state.sqlType = 'json' - return `ifnull(${res}, json_array())` - } - - protected parseSelection(sel: Selection) { - if (!this.compat.maria && !this.compat.mysql57) return super.parseSelection(sel) - const { args: [expr], ref, table, tables } = sel - const restore = this.saveState({ wrappedSubquery: true, tables }) - const inner = this.get(table as Selection, true, true) as string - const output = this.parseEval(expr, false) - const refFields = this.state.refFields - restore() - let query: string - if (!(sel.args[0] as any).$) { - query = `(SELECT ${output} AS value FROM ${inner} ${isBracketed(inner) ? ref : ''})` - } else { - query = `(ifnull((SELECT ${this.groupArray(output)} AS value FROM ${inner} ${isBracketed(inner) ? ref : ''}), json_array()))` - } - if (Object.keys(refFields ?? {}).length) { - const funcname = `minato_tfunc_${randomId()}` - const decls = Object.values(refFields ?? {}).map(x => `${x} JSON`).join(',') - const args = Object.keys(refFields ?? {}).map(x => this.state.refFields?.[x] ?? x).map(x => this.jsonQuote(x, true)).join(',') - query = this.state.sqlType === 'json' ? `ifnull(${query}, json_array())` : this.jsonQuote(query) - this.prequeries.push(`DROP FUNCTION IF EXISTS ${funcname}`) - this.prequeries.push(`CREATE FUNCTION ${funcname} (${decls}) RETURNS JSON DETERMINISTIC RETURN ${query}`) - this.state.sqlType = 'json' - return `${funcname}(${args})` - } else return query - } - - toUpdateExpr(item: any, key: string, field?: Field, upsert?: boolean) { - const escaped = escapeId(key) - - // update directly - if (key in item) { - if (!isEvalExpr(item[key]) && upsert) { - return `VALUES(${escaped})` - } else if (isEvalExpr(item[key])) { - return this.parseEval(item[key]) - } else { - return this.escape(item[key], field) - } - } - - // prepare nested layout - const jsonInit = {} - for (const prop in item) { - if (!prop.startsWith(key + '.')) continue - const rest = prop.slice(key.length + 1).split('.') - if (rest.length === 1) continue - rest.reduce((obj, k) => obj[k] ??= {}, jsonInit) - } - - // update with json_set - const valueInit = `ifnull(${escaped}, '{}')` - let value = valueInit - - // json_set cannot create deeply nested property when non-exist - // therefore we merge a layout to it - if (Object.keys(jsonInit).length !== 0) { - value = `json_merge(${value}, ${this.quote(JSON.stringify(jsonInit))})` - } - - for (const prop in item) { - if (!prop.startsWith(key + '.')) continue - const rest = prop.slice(key.length + 1).split('.') - value = `json_set(${value}, '$${rest.map(key => `."${key}"`).join('')}', ${this.parseEval(item[prop])})` - } - - if (value === valueInit) { - return escaped - } else { - return value - } - } -} - export class MySQLDriver extends Driver { static name = 'mysql' diff --git a/packages/postgres/src/builder.ts b/packages/postgres/src/builder.ts new file mode 100644 index 00000000..2197283d --- /dev/null +++ b/packages/postgres/src/builder.ts @@ -0,0 +1,293 @@ +import { Builder, escapeId, isBracketed } from '@minatojs/sql-utils' +import { Dict, isNullable, Time } from 'cosmokit' +import { Field, isEvalExpr, Model, randomId, Selection } from 'minato' + +const timeRegex = /(\d+):(\d+):(\d+)/ + +export function formatTime(time: Date) { + const year = time.getFullYear().toString() + const month = Time.toDigits(time.getMonth() + 1) + const date = Time.toDigits(time.getDate()) + const hour = Time.toDigits(time.getHours()) + const min = Time.toDigits(time.getMinutes()) + const sec = Time.toDigits(time.getSeconds()) + const ms = Time.toDigits(time.getMilliseconds(), 3) + let timezone = Time.toDigits(time.getTimezoneOffset() / -60) + if (!timezone.startsWith('-')) timezone = `+${timezone}` + return `${year}-${month}-${date} ${hour}:${min}:${sec}.${ms}${timezone}` +} + +export class PostgresBuilder extends Builder { + // eslint-disable-next-line no-control-regex + protected escapeRegExp = /[\0\b\t\n\r\x1a'\\]/g + protected escapeMap = { + '\0': '\\0', + '\b': '\\b', + '\t': '\\t', + '\n': '\\n', + '\r': '\\r', + '\x1a': '\\Z', + '\'': '\'\'', + '\\': '\\\\', + } + + protected $true = 'TRUE' + protected $false = 'FALSE' + + constructor(public tables?: Dict) { + super(tables) + + this.queryOperators = { + ...this.queryOperators, + $regex: (key, value) => this.createRegExpQuery(key, value), + $regexFor: (key, value) => `${this.escape(value)} ~ ${key}`, + $size: (key, value) => { + if (!value) return this.logicalNot(key) + if (this.state.sqlTypes?.[this.unescapeId(key)] === 'json') { + return `${this.jsonLength(key)} = ${this.escape(value)}` + } else { + return `${key} IS NOT NULL AND ARRAY_LENGTH(${key}, 1) = ${value}` + } + }, + } + + this.evalOperators = { + ...this.evalOperators, + $if: (args) => { + const type = this.getLiteralType(args[1]) ?? this.getLiteralType(args[2]) ?? 'text' + return `(SELECT CASE WHEN ${this.parseEval(args[0], 'boolean')} THEN ${this.parseEval(args[1], type)} ELSE ${this.parseEval(args[2], type)} END)` + }, + $ifNull: (args) => { + const type = args.map(this.getLiteralType).find(x => x) ?? 'text' + return `coalesce(${args.map(arg => this.parseEval(arg, type)).join(', ')})` + }, + + $regex: ([key, value]) => `${this.parseEval(key)} ~ ${this.parseEval(value)}`, + + // number + $add: (args) => `(${args.map(arg => this.parseEval(arg, 'double precision')).join(' + ')})`, + $multiply: (args) => `(${args.map(arg => this.parseEval(arg, 'double precision')).join(' * ')})`, + $modulo: ([left, right]) => { + const dividend = this.parseEval(left, 'double precision'), divisor = this.parseEval(right, 'double precision') + return `${dividend} - (${divisor} * floor(${dividend} / ${divisor}))` + }, + $log: ([left, right]) => isNullable(right) + ? `ln(${this.parseEval(left, 'double precision')})` + : `ln(${this.parseEval(left, 'double precision')}) / ln(${this.parseEval(right, 'double precision')})`, + $random: () => `random()`, + + $eq: this.binary('=', 'text'), + + $number: (arg) => { + const value = this.parseEval(arg) + const res = this.state.sqlType === 'raw' ? `${value}::double precision` + : `extract(epoch from ${value})::bigint` + this.state.sqlType = 'raw' + return `coalesce(${res}, 0)` + }, + + $sum: (expr) => this.createAggr(expr, value => `coalesce(sum(${value})::double precision, 0)`, undefined, 'double precision'), + $avg: (expr) => this.createAggr(expr, value => `avg(${value})::double precision`, undefined, 'double precision'), + $min: (expr) => this.createAggr(expr, value => `min(${value})`, undefined, 'double precision'), + $max: (expr) => this.createAggr(expr, value => `max(${value})`, undefined, 'double precision'), + $count: (expr) => this.createAggr(expr, value => `count(distinct ${value})::integer`), + $length: (expr) => this.createAggr(expr, value => `count(${value})::integer`, value => { + if (this.state.sqlType === 'json') { + this.state.sqlType = 'raw' + return `${this.jsonLength(value)}` + } else { + this.state.sqlType = 'raw' + return `COALESCE(ARRAY_LENGTH(${value}, 1), 0)` + } + }), + + $concat: (args) => `${args.map(arg => this.parseEval(arg, 'text')).join('||')}`, + } + + this.define({ + types: ['time'], + dump: date => date ? (typeof date === 'string' ? date : formatTime(date)) : null, + load: str => { + if (isNullable(str)) return str + const date = new Date(0) + const parsed = timeRegex.exec(str) + if (!parsed) throw Error(`unexpected time value: ${str}`) + date.setHours(+parsed[1], +parsed[2], +parsed[3]) + return date + }, + }) + + this.define({ + types: ['list'], + dump: value => '{' + value.join(',') + '}', + load: value => value, + }) + } + + upsert(table: string) { + this.modifiedTable = table + } + + protected binary(operator: string, eltype: string = 'double precision') { + return ([left, right]) => { + const type = this.getLiteralType(left) ?? this.getLiteralType(right) ?? eltype + return `(${this.parseEval(left, type)} ${operator} ${this.parseEval(right, type)})` + } + } + + private getLiteralType(expr: any) { + if (typeof expr === 'string') return 'text' + else if (typeof expr === 'number') return 'double precision' + else if (typeof expr === 'string') return 'boolean' + } + + parseEval(expr: any, outtype: boolean | string = false): string { + this.state.sqlType = 'raw' + if (typeof expr === 'string' || typeof expr === 'number' || typeof expr === 'boolean' || expr instanceof Date || expr instanceof RegExp) { + return this.escape(expr) + } + return outtype ? this.jsonUnquote(this.parseEvalExpr(expr), false, typeof outtype === 'string' ? outtype : undefined) : this.parseEvalExpr(expr) + } + + protected createRegExpQuery(key: string, value: string | RegExp) { + return `${key} ~ ${this.escape(typeof value === 'string' ? value : value.source)}` + } + + protected createElementQuery(key: string, value: any) { + if (this.state.sqlTypes?.[this.unescapeId(key)] === 'json') { + return this.jsonContains(key, this.quote(JSON.stringify(value))) + } else { + return `${key} && ARRAY['${value}']::TEXT[]` + } + } + + protected createAggr(expr: any, aggr: (value: string) => string, nonaggr?: (value: string) => string, eltype?: string) { + if (!this.state.group && !nonaggr) { + const value = this.parseEval(expr, false) + return `(select ${aggr(this.jsonUnquote(this.escapeId('value'), true, eltype))} from jsonb_array_elements(${value}) ${randomId()})` + } else { + return super.createAggr(expr, aggr, nonaggr) + } + } + + protected transformJsonField(obj: string, path: string) { + this.state.sqlType = 'json' + return `jsonb_extract_path(${obj}, ${path.slice(1).replace('.', ',')})` + } + + protected jsonLength(value: string) { + return `jsonb_array_length(${value})` + } + + protected jsonContains(obj: string, value: string) { + return `(${obj} @> ${value})` + } + + protected jsonUnquote(value: string, pure: boolean = false, type?: string) { + if (pure && type) return `(jsonb_build_object('v', ${value})->>'v')::${type}` + if (this.state.sqlType === 'json') { + this.state.sqlType = 'raw' + return `(jsonb_build_object('v', ${value})->>'v')::${type}` + } + return value + } + + protected jsonQuote(value: string, pure: boolean = false) { + if (pure) return `to_jsonb(${value})` + if (this.state.sqlType !== 'json') { + this.state.sqlType = 'json' + return `to_jsonb(${value})` + } + return value + } + + protected groupObject(fields: any) { + const parse = (expr) => { + const value = this.parseEval(expr, false) + return this.state.sqlType === 'json' ? `to_jsonb(${value})` : `${value}` + } + const res = `jsonb_build_object(` + Object.entries(fields).map(([key, expr]) => `'${key}', ${parse(expr)}`).join(',') + `)` + this.state.sqlType = 'json' + return res + } + + protected groupArray(value: string) { + this.state.sqlType = 'json' + return `coalesce(jsonb_agg(${value}), '[]'::jsonb)` + } + + protected parseSelection(sel: Selection) { + const { args: [expr], ref, table, tables } = sel + const restore = this.saveState({ tables }) + const inner = this.get(table as Selection, true, true) as string + const output = this.parseEval(expr, false) + restore() + if (!(sel.args[0] as any).$) { + return `(SELECT ${output} AS value FROM ${inner} ${isBracketed(inner) ? ref : ''})` + } else { + return `(coalesce((SELECT ${this.groupArray(output)} AS value FROM ${inner} ${isBracketed(inner) ? ref : ''}), '[]'::jsonb))` + } + } + + escapeId = escapeId + + escapeKey(value: string) { + return `'${value}'` + } + + escape(value: any, field?: Field) { + if (value instanceof Date) { + value = formatTime(value) + } else if (value instanceof RegExp) { + value = value.source + } else if (!field && !!value && typeof value === 'object') { + return `${this.quote(JSON.stringify(value))}::jsonb` + } + return super.escape(value, field) + } + + toUpdateExpr(item: any, key: string, field?: Field, upsert?: boolean) { + const escaped = this.escapeId(key) + // update directly + if (key in item) { + if (!isEvalExpr(item[key]) && upsert) { + return `excluded.${escaped}` + } else if (isEvalExpr(item[key])) { + return this.parseEval(item[key]) + } else { + return this.escape(item[key], field) + } + } + + // prepare nested layout + const jsonInit = {} + for (const prop in item) { + if (!prop.startsWith(key + '.')) continue + const rest = prop.slice(key.length + 1).split('.') + if (rest.length === 1) continue + rest.reduce((obj, k) => obj[k] ??= {}, jsonInit) + } + + // update with json_set + const valueInit = this.modifiedTable ? `coalesce(${this.escapeId(this.modifiedTable)}.${escaped}, '{}')::jsonb` : `coalesce(${escaped}, '{}')::jsonb` + let value = valueInit + + // json_set cannot create deeply nested property when non-exist + // therefore we merge a layout to it + if (Object.keys(jsonInit).length !== 0) { + value = `(${value} || jsonb ${this.quote(JSON.stringify(jsonInit))})` + } + + for (const prop in item) { + if (!prop.startsWith(key + '.')) continue + const rest = prop.slice(key.length + 1).split('.') + value = `jsonb_set(${value}, '{${rest.map(key => `"${key}"`).join(',')}}', ${this.jsonQuote(this.parseEval(item[prop]), true)}, true)` + } + + if (value === valueInit) { + return this.modifiedTable ? `${this.escapeId(this.modifiedTable)}.${escaped}` : escaped + } else { + return value + } + } +} diff --git a/packages/postgres/src/index.ts b/packages/postgres/src/index.ts index d0a2f754..f3e62fd2 100644 --- a/packages/postgres/src/index.ts +++ b/packages/postgres/src/index.ts @@ -1,9 +1,8 @@ import postgres from 'postgres' -import { Dict, difference, isNullable, makeArray, pick, Time } from 'cosmokit' -import { Driver, Eval, executeUpdate, Field, isEvalExpr, Model, randomId, Selection, z } from 'minato' -import { Builder, isBracketed } from '@minatojs/sql-utils' - -const timeRegex = /(\d+):(\d+):(\d+)/ +import { Dict, difference, isNullable, makeArray, pick } from 'cosmokit' +import { Driver, Eval, executeUpdate, Field, Selection, z } from 'minato' +import { isBracketed } from '@minatojs/sql-utils' +import { formatTime, PostgresBuilder } from './builder' interface ColumnInfo { table_catalog: string @@ -148,298 +147,10 @@ function createIndex(keys: string | string[]) { return makeArray(keys).map(escapeId).join(', ') } -function formatTime(time: Date) { - const year = time.getFullYear().toString() - const month = Time.toDigits(time.getMonth() + 1) - const date = Time.toDigits(time.getDate()) - const hour = Time.toDigits(time.getHours()) - const min = Time.toDigits(time.getMinutes()) - const sec = Time.toDigits(time.getSeconds()) - const ms = Time.toDigits(time.getMilliseconds(), 3) - let timezone = Time.toDigits(time.getTimezoneOffset() / -60) - if (!timezone.startsWith('-')) timezone = `+${timezone}` - return `${year}-${month}-${date} ${hour}:${min}:${sec}.${ms}${timezone}` -} - function transformArray(arr: any[]) { return `ARRAY[${arr.map(v => `'${v.replace(/'/g, "''")}'`).join(',')}]::TEXT[]` } -class PostgresBuilder extends Builder { - // eslint-disable-next-line no-control-regex - protected escapeRegExp = /[\0\b\t\n\r\x1a'\\]/g - protected escapeMap = { - '\0': '\\0', - '\b': '\\b', - '\t': '\\t', - '\n': '\\n', - '\r': '\\r', - '\x1a': '\\Z', - '\'': '\'\'', - '\\': '\\\\', - } - - protected $true = 'TRUE' - protected $false = 'FALSE' - - constructor(public tables?: Dict) { - super(tables) - - this.queryOperators = { - ...this.queryOperators, - $regex: (key, value) => this.createRegExpQuery(key, value), - $regexFor: (key, value) => `${this.escape(value)} ~ ${key}`, - $size: (key, value) => { - if (!value) return this.logicalNot(key) - if (this.state.sqlTypes?.[this.unescapeId(key)] === 'json') { - return `${this.jsonLength(key)} = ${this.escape(value)}` - } else { - return `${key} IS NOT NULL AND ARRAY_LENGTH(${key}, 1) = ${value}` - } - }, - } - - this.evalOperators = { - ...this.evalOperators, - $if: (args) => { - const type = this.getLiteralType(args[1]) ?? this.getLiteralType(args[2]) ?? 'text' - return `(SELECT CASE WHEN ${this.parseEval(args[0], 'boolean')} THEN ${this.parseEval(args[1], type)} ELSE ${this.parseEval(args[2], type)} END)` - }, - $ifNull: (args) => { - const type = args.map(this.getLiteralType).find(x => x) ?? 'text' - return `coalesce(${args.map(arg => this.parseEval(arg, type)).join(', ')})` - }, - - $regex: ([key, value]) => `${this.parseEval(key)} ~ ${this.parseEval(value)}`, - - // number - $add: (args) => `(${args.map(arg => this.parseEval(arg, 'double precision')).join(' + ')})`, - $multiply: (args) => `(${args.map(arg => this.parseEval(arg, 'double precision')).join(' * ')})`, - $modulo: ([left, right]) => { - const dividend = this.parseEval(left, 'double precision'), divisor = this.parseEval(right, 'double precision') - return `${dividend} - (${divisor} * floor(${dividend} / ${divisor}))` - }, - $log: ([left, right]) => isNullable(right) - ? `ln(${this.parseEval(left, 'double precision')})` - : `ln(${this.parseEval(left, 'double precision')}) / ln(${this.parseEval(right, 'double precision')})`, - $random: () => `random()`, - - $eq: this.binary('=', 'text'), - - $number: (arg) => { - const value = this.parseEval(arg) - const res = this.state.sqlType === 'raw' ? `${value}::double precision` - : `extract(epoch from ${value})::bigint` - this.state.sqlType = 'raw' - return `coalesce(${res}, 0)` - }, - - $sum: (expr) => this.createAggr(expr, value => `coalesce(sum(${value})::double precision, 0)`, undefined, 'double precision'), - $avg: (expr) => this.createAggr(expr, value => `avg(${value})::double precision`, undefined, 'double precision'), - $min: (expr) => this.createAggr(expr, value => `min(${value})`, undefined, 'double precision'), - $max: (expr) => this.createAggr(expr, value => `max(${value})`, undefined, 'double precision'), - $count: (expr) => this.createAggr(expr, value => `count(distinct ${value})::integer`), - $length: (expr) => this.createAggr(expr, value => `count(${value})::integer`, value => { - if (this.state.sqlType === 'json') { - this.state.sqlType = 'raw' - return `${this.jsonLength(value)}` - } else { - this.state.sqlType = 'raw' - return `COALESCE(ARRAY_LENGTH(${value}, 1), 0)` - } - }), - - $concat: (args) => `${args.map(arg => this.parseEval(arg, 'text')).join('||')}`, - } - - this.define({ - types: ['time'], - dump: date => date ? (typeof date === 'string' ? date : formatTime(date)) : null, - load: str => { - if (isNullable(str)) return str - const date = new Date(0) - const parsed = timeRegex.exec(str) - if (!parsed) throw Error(`unexpected time value: ${str}`) - date.setHours(+parsed[1], +parsed[2], +parsed[3]) - return date - }, - }) - - this.define({ - types: ['list'], - dump: value => '{' + value.join(',') + '}', - load: value => value, - }) - } - - upsert(table: string) { - this.modifiedTable = table - } - - protected binary(operator: string, eltype: string = 'double precision') { - return ([left, right]) => { - const type = this.getLiteralType(left) ?? this.getLiteralType(right) ?? eltype - return `(${this.parseEval(left, type)} ${operator} ${this.parseEval(right, type)})` - } - } - - private getLiteralType(expr: any) { - if (typeof expr === 'string') return 'text' - else if (typeof expr === 'number') return 'double precision' - else if (typeof expr === 'string') return 'boolean' - } - - parseEval(expr: any, outtype: boolean | string = false): string { - this.state.sqlType = 'raw' - if (typeof expr === 'string' || typeof expr === 'number' || typeof expr === 'boolean' || expr instanceof Date || expr instanceof RegExp) { - return this.escape(expr) - } - return outtype ? this.jsonUnquote(this.parseEvalExpr(expr), false, typeof outtype === 'string' ? outtype : undefined) : this.parseEvalExpr(expr) - } - - protected createRegExpQuery(key: string, value: string | RegExp) { - return `${key} ~ ${this.escape(typeof value === 'string' ? value : value.source)}` - } - - protected createElementQuery(key: string, value: any) { - if (this.state.sqlTypes?.[this.unescapeId(key)] === 'json') { - return this.jsonContains(key, this.quote(JSON.stringify(value))) - } else { - return `${key} && ARRAY['${value}']::TEXT[]` - } - } - - protected createAggr(expr: any, aggr: (value: string) => string, nonaggr?: (value: string) => string, eltype?: string) { - if (!this.state.group && !nonaggr) { - const value = this.parseEval(expr, false) - return `(select ${aggr(this.jsonUnquote(this.escapeId('value'), true, eltype))} from jsonb_array_elements(${value}) ${randomId()})` - } else { - return super.createAggr(expr, aggr, nonaggr) - } - } - - protected transformJsonField(obj: string, path: string) { - this.state.sqlType = 'json' - return `jsonb_extract_path(${obj}, ${path.slice(1).replace('.', ',')})` - } - - protected jsonLength(value: string) { - return `jsonb_array_length(${value})` - } - - protected jsonContains(obj: string, value: string) { - return `(${obj} @> ${value})` - } - - protected jsonUnquote(value: string, pure: boolean = false, type?: string) { - if (pure && type) return `(jsonb_build_object('v', ${value})->>'v')::${type}` - if (this.state.sqlType === 'json') { - this.state.sqlType = 'raw' - return `(jsonb_build_object('v', ${value})->>'v')::${type}` - } - return value - } - - protected jsonQuote(value: string, pure: boolean = false) { - if (pure) return `to_jsonb(${value})` - if (this.state.sqlType !== 'json') { - this.state.sqlType = 'json' - return `to_jsonb(${value})` - } - return value - } - - protected groupObject(fields: any) { - const parse = (expr) => { - const value = this.parseEval(expr, false) - return this.state.sqlType === 'json' ? `to_jsonb(${value})` : `${value}` - } - const res = `jsonb_build_object(` + Object.entries(fields).map(([key, expr]) => `'${key}', ${parse(expr)}`).join(',') + `)` - this.state.sqlType = 'json' - return res - } - - protected groupArray(value: string) { - this.state.sqlType = 'json' - return `coalesce(jsonb_agg(${value}), '[]'::jsonb)` - } - - protected parseSelection(sel: Selection) { - const { args: [expr], ref, table, tables } = sel - const restore = this.saveState({ tables }) - const inner = this.get(table as Selection, true, true) as string - const output = this.parseEval(expr, false) - restore() - if (!(sel.args[0] as any).$) { - return `(SELECT ${output} AS value FROM ${inner} ${isBracketed(inner) ? ref : ''})` - } else { - return `(coalesce((SELECT ${this.groupArray(output)} AS value FROM ${inner} ${isBracketed(inner) ? ref : ''}), '[]'::jsonb))` - } - } - - escapeId = escapeId - - escapeKey(value: string) { - return `'${value}'` - } - - escape(value: any, field?: Field) { - if (value instanceof Date) { - value = formatTime(value) - } else if (value instanceof RegExp) { - value = value.source - } else if (!field && !!value && typeof value === 'object') { - return `${this.quote(JSON.stringify(value))}::jsonb` - } - return super.escape(value, field) - } - - toUpdateExpr(item: any, key: string, field?: Field, upsert?: boolean) { - const escaped = this.escapeId(key) - // update directly - if (key in item) { - if (!isEvalExpr(item[key]) && upsert) { - return `excluded.${escaped}` - } else if (isEvalExpr(item[key])) { - return this.parseEval(item[key]) - } else { - return this.escape(item[key], field) - } - } - - // prepare nested layout - const jsonInit = {} - for (const prop in item) { - if (!prop.startsWith(key + '.')) continue - const rest = prop.slice(key.length + 1).split('.') - if (rest.length === 1) continue - rest.reduce((obj, k) => obj[k] ??= {}, jsonInit) - } - - // update with json_set - const valueInit = this.modifiedTable ? `coalesce(${this.escapeId(this.modifiedTable)}.${escaped}, '{}')::jsonb` : `coalesce(${escaped}, '{}')::jsonb` - let value = valueInit - - // json_set cannot create deeply nested property when non-exist - // therefore we merge a layout to it - if (Object.keys(jsonInit).length !== 0) { - value = `(${value} || jsonb ${this.quote(JSON.stringify(jsonInit))})` - } - - for (const prop in item) { - if (!prop.startsWith(key + '.')) continue - const rest = prop.slice(key.length + 1).split('.') - value = `jsonb_set(${value}, '{${rest.map(key => `"${key}"`).join(',')}}', ${this.jsonQuote(this.parseEval(item[prop]), true)}, true)` - } - - if (value === valueInit) { - return this.modifiedTable ? `${this.escapeId(this.modifiedTable)}.${escaped}` : escaped - } else { - return value - } - } -} - export class PostgresDriver extends Driver { static name = 'postgres' diff --git a/packages/sqlite/src/builder.ts b/packages/sqlite/src/builder.ts new file mode 100644 index 00000000..ebef3f2b --- /dev/null +++ b/packages/sqlite/src/builder.ts @@ -0,0 +1,106 @@ +import { Builder, escapeId } from '@minatojs/sql-utils' +import { Dict, isNullable } from 'cosmokit' +import { Field, Model, randomId } from 'minato' + +export class SQLiteBuilder extends Builder { + protected escapeMap = { + "'": "''", + } + + constructor(tables?: Dict) { + super(tables) + + this.evalOperators.$if = (args) => `iif(${args.map(arg => this.parseEval(arg)).join(', ')})` + this.evalOperators.$concat = (args) => `(${args.map(arg => this.parseEval(arg)).join('||')})` + this.evalOperators.$modulo = ([left, right]) => `modulo(${this.parseEval(left)}, ${this.parseEval(right)})` + this.evalOperators.$log = ([left, right]) => isNullable(right) + ? `log(${this.parseEval(left)})` + : `log(${this.parseEval(left)}) / log(${this.parseEval(right)})` + this.evalOperators.$length = (expr) => this.createAggr(expr, value => `count(${value})`, value => { + if (this.state.sqlType === 'json') { + this.state.sqlType = 'raw' + return `${this.jsonLength(value)}` + } else { + this.state.sqlType = 'raw' + return `iif(${value}, LENGTH(${value}) - LENGTH(REPLACE(${value}, ${this.escape(',')}, ${this.escape('')})) + 1, 0)` + } + }) + this.evalOperators.$number = (arg) => { + const value = this.parseEval(arg) + const res = this.state.sqlType === 'raw' ? `cast(${this.parseEval(arg)} as double)` + : `cast(${value} / 1000 as integer)` + this.state.sqlType = 'raw' + return `ifnull(${res}, 0)` + } + + this.define({ + types: ['boolean'], + dump: value => +value, + load: (value) => !!value, + }) + + this.define({ + types: ['json'], + dump: value => JSON.stringify(value), + load: (value, initial) => value ? JSON.parse(value) : initial, + }) + + this.define({ + types: ['list'], + dump: value => Array.isArray(value) ? value.join(',') : value, + load: (value) => value ? value.split(',') : [], + }) + + this.define({ + types: ['date', 'time', 'timestamp'], + dump: value => value === null ? null : +new Date(value), + load: (value) => value === null ? null : new Date(value), + }) + } + + escape(value: any, field?: Field) { + if (value instanceof Date) value = +value + else if (value instanceof RegExp) value = value.source + return super.escape(value, field) + } + + protected createElementQuery(key: string, value: any) { + if (this.state.sqlTypes?.[this.unescapeId(key)] === 'json') { + return this.jsonContains(key, this.quote(JSON.stringify(value))) + } else { + return `(',' || ${key} || ',') LIKE ${this.escape('%,' + value + ',%')}` + } + } + + protected jsonLength(value: string) { + return `json_array_length(${value})` + } + + protected jsonContains(obj: string, value: string) { + return `json_array_contains(${obj}, ${value})` + } + + protected jsonUnquote(value: string, pure: boolean = false) { + return value + } + + protected createAggr(expr: any, aggr: (value: string) => string, nonaggr?: (value: string) => string) { + if (!this.state.group && !nonaggr) { + const value = this.parseEval(expr, false) + return `(select ${aggr(escapeId('value'))} from json_each(${value}) ${randomId()})` + } else { + return super.createAggr(expr, aggr, nonaggr) + } + } + + protected groupArray(value: string) { + const res = this.state.sqlType === 'json' ? `('[' || group_concat(${value}) || ']')` : `('[' || group_concat(json_quote(${value})) || ']')` + this.state.sqlType = 'json' + return `ifnull(${res}, json_array())` + } + + protected transformJsonField(obj: string, path: string) { + this.state.sqlType = 'raw' + return `json_extract(${obj}, '$${path}')` + } +} diff --git a/packages/sqlite/src/index.ts b/packages/sqlite/src/index.ts index 5f136f37..cbc26c26 100644 --- a/packages/sqlite/src/index.ts +++ b/packages/sqlite/src/index.ts @@ -1,11 +1,12 @@ import { clone, deepEqual, Dict, difference, isNullable, makeArray } from 'cosmokit' -import { Driver, Eval, executeUpdate, Field, Model, randomId, Selection, z } from 'minato' -import { Builder, escapeId } from '@minatojs/sql-utils' +import { Driver, Eval, executeUpdate, Field, Selection, z } from 'minato' +import { escapeId } from '@minatojs/sql-utils' import { resolve } from 'node:path' import { readFile, writeFile } from 'node:fs/promises' import init from '@minatojs/sql.js' import enUS from './locales/en-US.yml' import zhCN from './locales/zh-CN.yml' +import { SQLiteBuilder } from './builder' function getTypeDef({ type }: Field) { switch (type) { @@ -36,109 +37,6 @@ export interface SQLiteFieldInfo { pk: boolean } -class SQLiteBuilder extends Builder { - protected escapeMap = { - "'": "''", - } - - constructor(tables?: Dict) { - super(tables) - - this.evalOperators.$if = (args) => `iif(${args.map(arg => this.parseEval(arg)).join(', ')})` - this.evalOperators.$concat = (args) => `(${args.map(arg => this.parseEval(arg)).join('||')})` - this.evalOperators.$modulo = ([left, right]) => `modulo(${this.parseEval(left)}, ${this.parseEval(right)})` - this.evalOperators.$log = ([left, right]) => isNullable(right) - ? `log(${this.parseEval(left)})` - : `log(${this.parseEval(left)}) / log(${this.parseEval(right)})` - this.evalOperators.$length = (expr) => this.createAggr(expr, value => `count(${value})`, value => { - if (this.state.sqlType === 'json') { - this.state.sqlType = 'raw' - return `${this.jsonLength(value)}` - } else { - this.state.sqlType = 'raw' - return `iif(${value}, LENGTH(${value}) - LENGTH(REPLACE(${value}, ${this.escape(',')}, ${this.escape('')})) + 1, 0)` - } - }) - this.evalOperators.$number = (arg) => { - const value = this.parseEval(arg) - const res = this.state.sqlType === 'raw' ? `cast(${this.parseEval(arg)} as double)` - : `cast(${value} / 1000 as integer)` - this.state.sqlType = 'raw' - return `ifnull(${res}, 0)` - } - - this.define({ - types: ['boolean'], - dump: value => +value, - load: (value) => !!value, - }) - - this.define({ - types: ['json'], - dump: value => JSON.stringify(value), - load: (value, initial) => value ? JSON.parse(value) : initial, - }) - - this.define({ - types: ['list'], - dump: value => Array.isArray(value) ? value.join(',') : value, - load: (value) => value ? value.split(',') : [], - }) - - this.define({ - types: ['date', 'time', 'timestamp'], - dump: value => value === null ? null : +new Date(value), - load: (value) => value === null ? null : new Date(value), - }) - } - - escape(value: any, field?: Field) { - if (value instanceof Date) value = +value - else if (value instanceof RegExp) value = value.source - return super.escape(value, field) - } - - protected createElementQuery(key: string, value: any) { - if (this.state.sqlTypes?.[this.unescapeId(key)] === 'json') { - return this.jsonContains(key, this.quote(JSON.stringify(value))) - } else { - return `(',' || ${key} || ',') LIKE ${this.escape('%,' + value + ',%')}` - } - } - - protected jsonLength(value: string) { - return `json_array_length(${value})` - } - - protected jsonContains(obj: string, value: string) { - return `json_array_contains(${obj}, ${value})` - } - - protected jsonUnquote(value: string, pure: boolean = false) { - return value - } - - protected createAggr(expr: any, aggr: (value: string) => string, nonaggr?: (value: string) => string) { - if (!this.state.group && !nonaggr) { - const value = this.parseEval(expr, false) - return `(select ${aggr(escapeId('value'))} from json_each(${value}) ${randomId()})` - } else { - return super.createAggr(expr, aggr, nonaggr) - } - } - - protected groupArray(value: string) { - const res = this.state.sqlType === 'json' ? `('[' || group_concat(${value}) || ']')` : `('[' || group_concat(json_quote(${value})) || ']')` - this.state.sqlType = 'json' - return `ifnull(${res}, json_array())` - } - - protected transformJsonField(obj: string, path: string) { - this.state.sqlType = 'raw' - return `json_extract(${obj}, '$${path}')` - } -} - export class SQLiteDriver extends Driver { static name = 'sqlite'