Skip to content

Commit

Permalink
feat(minato): support subquery in set (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hieuzest authored Apr 20, 2024
1 parent dd32785 commit 76316a2
Show file tree
Hide file tree
Showing 16 changed files with 173 additions and 65 deletions.
6 changes: 3 additions & 3 deletions packages/core/src/database.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Dict, makeArray, mapValues, MaybeArray, omit, valueMap } from 'cosmokit'
import { defineProperty, Dict, makeArray, mapValues, MaybeArray, omit } from 'cosmokit'
import { Context, Service, Spread } from 'cordis'
import { FlatKeys, FlatPick, Indexable, Keys, randomId, Row, unravel } from './utils.ts'
import { Selection } from './selection.ts'
Expand Down Expand Up @@ -117,7 +117,7 @@ export class Database<S = {}, N = {}, C extends Context = Context> extends Servi
})
model.extend(fields, config)
if (makeArray(model.primary).every(key => key in fields)) {
model.ctx = this[Context.origin]
defineProperty(model, 'ctx', this[Context.origin])
}
this.prepareTasks[name] = this.prepare(name)
;(this.ctx as Context).emit('model', name)
Expand Down Expand Up @@ -252,7 +252,7 @@ export class Database<S = {}, N = {}, C extends Context = Context> extends Servi
if (Array.isArray(oldTables)) {
tables = Object.fromEntries(oldTables.map((name) => [name, this.select(name)]))
}
const sels = valueMap(tables, (t: TableLike<S>) => {
const sels = mapValues(tables, (t: TableLike<S>) => {
return typeof t === 'string' ? this.select(t) : t
})
if (Object.keys(sels).length === 0) throw new Error('no tables to join')
Expand Down
4 changes: 2 additions & 2 deletions packages/core/src/driver.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Awaitable, Dict, remove, valueMap } from 'cosmokit'
import { Awaitable, Dict, mapValues, remove } from 'cosmokit'
import { Context, Logger } from 'cordis'
import { Eval, Update } from './eval.ts'
import { Direction, Modifier, Selection } from './selection.ts'
Expand Down Expand Up @@ -95,7 +95,7 @@ export abstract class Driver<T = any, C extends Context = Context> {
if (table instanceof Selection) {
if (!table.args[0].fields) return table.model
const model = new Model('temp')
model.fields = valueMap(table.args[0].fields, (expr, key) => ({
model.fields = mapValues(table.args[0].fields, (expr, key) => ({
type: Type.fromTerm(expr),
}))
return model
Expand Down
36 changes: 24 additions & 12 deletions packages/core/src/eval.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { defineProperty, isNullable, valueMap } from 'cosmokit'
import { defineProperty, isNullable, mapValues } from 'cosmokit'
import { Comparable, Flatten, isComparable, makeRegExp, Row } from './utils.ts'
import { Type } from './type.ts'
import { Field } from './model.ts'
Expand All @@ -7,6 +7,18 @@ export function isEvalExpr(value: any): value is Eval.Expr {
return value && Object.keys(value).some(key => key.startsWith('$'))
}

export function hasSubquery(value: any): boolean {
if (!isEvalExpr(value)) return false
return Object.entries(value).filter(([k]) => k.startsWith('$')).some(([k, v]) => {
if (isNullable(v) || isComparable(v)) return false
if (k === '$exec') return true
if (isEvalExpr(v)) return hasSubquery(v)
if (Array.isArray(v)) return v.some(x => hasSubquery(x))
if (typeof v === 'object') return Object.values(v).some(x => hasSubquery(x))
return false
})
}

export type Uneval<U, A extends boolean> =
| U extends number ? Eval.Term<number, A>
: U extends string ? Eval.Term<string, A>
Expand Down Expand Up @@ -41,9 +53,9 @@ export namespace Eval {
export type Binary<S, R> = <T extends S, A extends boolean>(x: Term<T, A>, y: Term<T, A>) => Expr<R, A>
export type Multi<S, R> = <T extends S, A extends boolean>(...args: Term<T, A>[]) => Expr<R, A>

export interface Aggr<S, R> {
<T extends S>(value: Term<T, false>): Expr<R, true>
<T extends S, A extends boolean>(value: Array<T, A>): Expr<R, A>
export interface Aggr<S> {
<T extends S>(value: Term<T, false>): Expr<T, true>
<T extends S, A extends boolean>(value: Array<T, A>): Expr<T, A>
}

export interface Branch<T, A extends boolean> {
Expand Down Expand Up @@ -105,14 +117,14 @@ export namespace Eval {
not: Unary<boolean, boolean>

// typecast
literal<T>(value: T, type?: Field.Type<T> | Field.NewType<T> | string): Expr<T, false>
literal<T>(value: T, type?: Type<T> | Field.Type<T> | Field.NewType<T> | string): Expr<T, false>
number: Unary<any, number>

// aggregation / json
sum: Aggr<number, number>
avg: Aggr<number, number>
max: Aggr<number, number> & Aggr<Date, Date>
min: Aggr<number, number> & Aggr<Date, Date>
sum: Aggr<number>
avg: Aggr<number>
max: Aggr<Comparable>
min: Aggr<Comparable>
count(value: Any<false>): Expr<number, true>
length(value: Any<false>): Expr<number, true>
size<A extends boolean>(value: (Any | Expr<Any, A>)[] | Expr<Any[], A>): Expr<number, A>
Expand Down Expand Up @@ -246,7 +258,7 @@ defineProperty(Eval, 'length', unary('length', (expr, table) => Array.isArray(ta
? table.map(data => executeAggr(expr, data)).length
: Array.from(executeEval(table, expr)).length, Type.Number))

operators.$object = (field, table) => valueMap(field, value => executeAggr(value, table))
operators.$object = (field, table) => mapValues(field, value => executeAggr(value, table))
Eval.object = (fields: any) => {
if (fields.$model) {
const modelFields: [string, Field][] = Object.entries(fields.$model.fields)
Expand All @@ -255,9 +267,9 @@ Eval.object = (fields: any) => {
.filter(([, field]) => !field.deprecated)
.filter(([path]) => path.startsWith(prefix))
.map(([k]) => [k.slice(prefix.length), fields[k.slice(prefix.length)]]))
return Eval('object', fields, Type.Object(valueMap(fields, (value) => Type.fromTerm(value))))
return Eval('object', fields, Type.Object(mapValues(fields, (value) => Type.fromTerm(value))))
}
return Eval('object', fields, Type.Object(valueMap(fields, (value) => Type.fromTerm(value)))) as any
return Eval('object', fields, Type.Object(mapValues(fields, (value) => Type.fromTerm(value)))) as any
}

Eval.array = unary('array', (expr, table) => Array.isArray(table)
Expand Down
4 changes: 2 additions & 2 deletions packages/core/src/model.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Binary, clone, isNullable, makeArray, MaybeArray, valueMap } from 'cosmokit'
import { Binary, clone, isNullable, makeArray, mapValues, MaybeArray } from 'cosmokit'
import { Context } from 'cordis'
import { Eval, isEvalExpr } from './eval.ts'
import { Flatten, Keys, unravel } from './utils.ts'
Expand Down Expand Up @@ -312,7 +312,7 @@ export class Model<S = any> {
getType(): Type<S>
getType(key: string): Type | undefined
getType(key?: string): Type | undefined {
this.type ??= Type.Object(valueMap(this.fields!, field => Type.fromField(field!))) as any
this.type ??= Type.Object(mapValues(this.fields!, field => Type.fromField(field!))) as any
return key ? Type.getInner(this.type, key) : this.type
}
}
4 changes: 2 additions & 2 deletions packages/core/src/selection.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { defineProperty, Dict, filterKeys, valueMap } from 'cosmokit'
import { defineProperty, Dict, filterKeys, mapValues } from 'cosmokit'
import { Driver } from './driver.ts'
import { Eval, executeEval } from './eval.ts'
import { Model } from './model.ts'
Expand Down Expand Up @@ -119,7 +119,7 @@ class Executable<S = any, T = any> {
})
return Object.fromEntries(entries)
} else {
return valueMap(fields, field => this.resolveField(field))
return mapValues(fields, field => this.resolveField(field))
}
}

Expand Down
1 change: 1 addition & 0 deletions packages/core/src/type.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ export namespace Type {
else if (typeof value === 'number') return Number as any
else if (typeof value === 'string') return String as any
else if (typeof value === 'boolean') return Boolean as any
else if (typeof value === 'bigint') return fromField('bigint' as any)
else if (value instanceof Date) return fromField('timestamp' as any)
else if (Binary.is(value)) return fromField('binary' as any)
else if (globalThis.Array.isArray(value)) return Array(value.length ? fromPrimitive(value[0]) : undefined) as any
Expand Down
4 changes: 2 additions & 2 deletions packages/memory/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { clone, Dict, makeArray, noop, omit, pick, valueMap } from 'cosmokit'
import { clone, Dict, makeArray, mapValues, noop, omit, pick } from 'cosmokit'
import { Driver, Eval, executeEval, executeQuery, executeSort, executeUpdate, RuntimeError, Selection, z } from 'minato'

export class MemoryDriver extends Driver<MemoryDriver.Config> {
Expand Down Expand Up @@ -61,7 +61,7 @@ export class MemoryDriver extends Driver<MemoryDriver.Config> {
}
let index = row
if (fields) {
index = valueMap(groupFields!, (expr) => executeEval({ ...env, [ref]: row }, expr))
index = mapValues(groupFields!, (expr) => executeEval({ ...env, [ref]: row }, expr))
}
let branch = branches.find((branch) => {
if (!group || !groupFields) return false
Expand Down
25 changes: 17 additions & 8 deletions packages/mongo/src/builder.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Dict, isNullable, mapValues, valueMap } from 'cosmokit'
import { Dict, isNullable, mapValues } from 'cosmokit'
import { Driver, Eval, isComparable, isEvalExpr, Model, Query, Selection, Type, unravel } from 'minato'
import { Filter, FilterOperators, ObjectId } from 'mongodb'
import MongoDriver from '.'
Expand Down Expand Up @@ -113,7 +113,7 @@ export class Builder {
},
$if: (arg, group) => ({ $cond: arg.map(val => this.eval(val, group)) }),

$object: (arg, group) => valueMap(arg as any, x => this.transformEvalExpr(x)),
$object: (arg, group) => mapValues(arg as any, x => this.transformEvalExpr(x)),

$regex: (arg, group) => ({ $regexMatch: { input: this.eval(arg[0], group), regex: this.eval(arg[1], group) } }),

Expand All @@ -128,8 +128,7 @@ export class Builder {
$random: (arg, group) => ({ $rand: {} }),

$literal: (arg, group) => {
const converter = this.driver.types[arg[1] as any]
return converter ? converter.dump(arg[0]) : arg[0]
return { $literal: this.dump(arg[0], arg[1] ? Type.fromField(arg[1]) : undefined) }
},
$number: (arg, group) => {
const value = this.eval(arg, group)
Expand Down Expand Up @@ -209,7 +208,7 @@ export class Builder {
if (this.evalOperators[key]) {
return this.evalOperators[key](expr[key], group)
} else if (key?.startsWith('$') && Eval[key.slice(1)]) {
return valueMap(expr, (value) => {
return mapValues(expr, (value) => {
if (Array.isArray(value)) {
return value.map(val => this.eval(val, group))
} else {
Expand Down Expand Up @@ -362,7 +361,7 @@ export class Builder {
stages.push(...this.flushLookups(), ...groupStages, { $project })
$group['_id'] = unravel($group['_id'])
} else if (fields) {
const $project = valueMap(fields, (expr) => this.eval(expr))
const $project = mapValues(fields, (expr) => this.eval(expr))
$project._id = 0
stages.push(...this.flushLookups(), { $project })
} else {
Expand All @@ -381,8 +380,8 @@ export class Builder {
return predecessor.select(sel)
}

public select(sel: Selection.Immutable) {
const { table, query } = sel
public select(sel: Selection.Immutable, update?: any) {
const { model, table, query } = sel
if (typeof table === 'string') {
this.table = table
this.refVirtualKeys[sel.ref] = this.virtualKey = (sel.driver as MongoDriver).getVirtualKey(table)!
Expand Down Expand Up @@ -438,6 +437,16 @@ export class Builder {
this.pipeline.push({ $group }, ...this.flushLookups(), { $project })
}
this.evalKey = $
} else if (sel.type === 'set') {
const $set = mapValues(update, (expr, key) => this.eval(isEvalExpr(expr) ? expr : Eval.literal(expr, model.getType(key))))
this.pipeline.push(...this.flushLookups(), { $set }, {
$merge: {
into: table,
on: '_id',
whenMatched: 'replace',
whenNotMatched: 'discard',
},
})
}
return this
}
Expand Down
59 changes: 37 additions & 22 deletions packages/mongo/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { BSONType, ClientSession, Collection, Db, IndexDescription, Long, MongoClient, MongoClientOptions, MongoError } from 'mongodb'
import { Binary, Dict, isNullable, makeArray, mapValues, noop, omit, pick } from 'cosmokit'
import { Driver, Eval, executeUpdate, Query, RuntimeError, Selection, z } from 'minato'
import { Driver, Eval, executeUpdate, hasSubquery, Query, RuntimeError, Selection, z } from 'minato'
import { URLSearchParams } from 'url'
import { Builder } from './builder'

Expand Down Expand Up @@ -302,17 +302,20 @@ export class MongoDriver extends Driver<MongoDriver.Config> {
async get(sel: Selection.Immutable) {
const transformer = new Builder(this, Object.keys(sel.tables)).select(sel)
if (!transformer) return []
this.logger.debug('%s %s', transformer.table, JSON.stringify(transformer.pipeline))
this.logPipeline(transformer.table, transformer.pipeline)
return this.db
.collection(transformer.table)
.aggregate(transformer.pipeline, { allowDiskUse: true, session: this.session })
.toArray().then(rows => rows.map(row => this.builder.load(row, sel.model)))
.toArray().then(rows => {
// console.dir(rows, { depth: 8 })
return rows.map(row => this.builder.load(row, sel.model))
})
}

async eval(sel: Selection.Immutable, expr: Eval.Expr) {
const transformer = new Builder(this, Object.keys(sel.tables)).select(sel)
if (!transformer) return
this.logger.debug('%s %s', transformer.table, JSON.stringify(transformer.pipeline))
this.logPipeline(transformer.table, transformer.pipeline)
const res = await this.db
.collection(transformer.table)
.aggregate(transformer.pipeline, { allowDiskUse: true, session: this.session })
Expand All @@ -322,25 +325,33 @@ export class MongoDriver extends Driver<MongoDriver.Config> {

async set(sel: Selection.Mutable, update: {}) {
const { query, table, model } = sel
const filter = this.transformQuery(sel, query, table)
if (!filter) return {}
const coll = this.db.collection(table)
if (hasSubquery(sel.query) || Object.values(update).some(x => hasSubquery(x))) {
const transformer = new Builder(this, Object.keys(sel.tables)).select(sel, update)!
await this.db.collection(transformer.table)
.aggregate(transformer.pipeline, { allowDiskUse: true, session: this.session })
.toArray()
return {} // result not available
} else {
const filter = this.transformQuery(sel, query, table)
if (!filter) return {}
const coll = this.db.collection(table)

const transformer = new Builder(this, Object.keys(sel.tables), this.getVirtualKey(table), '$' + tempKey + '.')
const $set = this.builder.formatUpdateAggr(model.getType(), mapValues(this.builder.dump(update, model),
(value: any) => typeof value === 'string' && value.startsWith('$') ? { $literal: value } : transformer.eval(value)))
const $unset = Object.entries($set)
.filter(([_, value]) => typeof value === 'object')
.map(([key, _]) => key)
const preset = Object.fromEntries(transformer.walkedKeys.map(key => [tempKey + '.' + key, '$' + key]))

const result = await coll.updateMany(filter, [
...transformer.walkedKeys.length ? [{ $set: preset }] : [],
...$unset.length ? [{ $unset }] : [],
{ $set },
...transformer.walkedKeys.length ? [{ $unset: [tempKey] }] : [],
], { session: this.session })
return { matched: result.matchedCount, modified: result.modifiedCount }
const transformer = new Builder(this, Object.keys(sel.tables), this.getVirtualKey(table), '$' + tempKey + '.')
const $set = this.builder.formatUpdateAggr(model.getType(), mapValues(this.builder.dump(update, model),
(value: any) => typeof value === 'string' && value.startsWith('$') ? { $literal: value } : transformer.eval(value)))
const $unset = Object.entries($set)
.filter(([_, value]) => typeof value === 'object')
.map(([key, _]) => key)
const preset = Object.fromEntries(transformer.walkedKeys.map(key => [tempKey + '.' + key, '$' + key]))

const result = await coll.updateMany(filter, [
...transformer.walkedKeys.length ? [{ $set: preset }] : [],
...$unset.length ? [{ $unset }] : [],
{ $set },
...transformer.walkedKeys.length ? [{ $unset: [tempKey] }] : [],
], { session: this.session })
return { matched: result.matchedCount, modified: result.modifiedCount }
}
}

async remove(sel: Selection.Mutable) {
Expand Down Expand Up @@ -476,6 +487,10 @@ See https://www.mongodb.com/docs/manual/tutorial/convert-standalone-to-replica-s
await callback(undefined)
}
}

logPipeline(table: string, pipeline: any) {
this.logger.debug('%s %s', table, JSON.stringify(pipeline, (_, value) => typeof value === 'bigint' ? `${value}n` : value))
}
}

export namespace MongoDriver {
Expand Down
3 changes: 1 addition & 2 deletions packages/mysql/src/builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,7 @@ export class MySQLBuilder extends Builder {

protected encode(value: string, encoded: boolean, pure: boolean = false, type?: Type) {
return this.asEncoded(encoded === this.isEncoded() && !pure ? value : encoded
? (this.compat.maria ? `json_extract(json_object('v', ${this.transform(value, type, 'encode')}), '$.v')`
: `cast(${this.transform(value, type, 'encode')} as json)`)
? `json_extract(json_object('v', ${this.transform(value, type, 'encode')}), '$.v')`
: this.transform(`json_unquote(${value})`, type, 'decode'), pure ? undefined : encoded)
}

Expand Down
3 changes: 2 additions & 1 deletion packages/mysql/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,8 @@ INSERT INTO mtt VALUES(json_extract(j, concat('$[', i, ']'))); SET i=i+1; END WH
const escaped = escapeId(field)
return `${escaped} = ${builder.toUpdateExpr(data, field, fields[field], false)}`
}).join(', ')
const result = await this.query(`UPDATE ${escapeId(table)} ${ref} SET ${update} WHERE ${filter}`)
const sql = [...builder.prequeries, `UPDATE ${escapeId(table)} ${ref} SET ${update} WHERE ${filter}`].join('; ')
const result = await this.query(sql)
return { matched: result.affectedRows, modified: result.changedRows }
}

Expand Down
Loading

0 comments on commit 76316a2

Please sign in to comment.