Skip to content

Commit

Permalink
feat(orm): support db.aggregate()
Browse files Browse the repository at this point in the history
  • Loading branch information
shigma committed Aug 21, 2021
1 parent 9172e95 commit ac2fac6
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 20 deletions.
4 changes: 2 additions & 2 deletions packages/koishi-core/src/database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ export namespace Query {
return modifier || {}
}

type Projection<T extends TableType, K extends string> = Record<K, Evaluation.Aggregation<Tables[T]>>
type Projection<T extends TableType, K extends string> = Record<K, Eval.Aggregation<Tables[T]>>

export interface Database {
drop(table?: TableType): Promise<void>
Expand All @@ -224,7 +224,7 @@ export namespace Query {
}
}

export namespace Evaluation {
export namespace Eval {
export type Numeric<T = any, U = never> = U | number | Keys<T, number> | NumericExpr<Numeric<T, U>>

export interface NumericExpr<N = Numeric> {
Expand Down
48 changes: 47 additions & 1 deletion packages/koishi-test-utils/src/memory.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { Tables, TableType, Query, App, Database, User, Channel } from 'koishi-core'
import { Tables, TableType, Query, App, Database, User, Channel, Eval } from 'koishi-core'
import { clone, pick } from 'koishi-utils'

declare module 'koishi-core' {
Expand Down Expand Up @@ -106,6 +106,46 @@ function executeQuery(query: Query.Expr, data: any): boolean {
})
}

function executeNumericExpr<U>(expr: Eval.NumericExpr<U>, data: any, execute: (expr: U, data: any) => number = executeNumeric) {
if ('$add' in expr) {
return expr.$add.reduce<number>((prev, curr) => prev + execute(curr, data), 0)
} else if ('$multiply' in expr) {
return expr.$multiply.reduce<number>((prev, curr) => prev * execute(curr, data), 1)
} else if ('$subtract' in expr) {
return execute(expr.$subtract[0], data) - execute(expr.$subtract[1], data)
} else if ('$divide' in expr) {
return execute(expr.$divide[0], data) / execute(expr.$divide[1], data)
}
}

function executeNumeric(expr: Eval.Numeric, data: any): number {
if (typeof expr === 'string') {
return data[expr]
} else if (typeof expr === 'number') {
return expr
} else {
return executeNumericExpr(expr, data)
}
}

function executeAggregation(expr: Eval.Aggregation, table: any[]): number {
if (typeof expr === 'number') {
return expr
} else if ('$sum' in expr) {
return table.reduce((prev, curr) => prev + executeNumeric(expr.$sum, curr), 0)
} else if ('$avg' in expr) {
return table.reduce((prev, curr) => prev + executeNumeric(expr.$avg, curr), 0) / table.length
} else if ('$min' in expr) {
return Math.min(...table.map(data => executeNumeric(expr.$min, data)))
} else if ('$max' in expr) {
return Math.max(...table.map(data => executeNumeric(expr.$max, data)))
} else if ('$count' in expr) {
return new Set(table.map(data => executeNumeric(expr.$count, data))).size
} else {
return executeNumericExpr(expr as never, table, executeAggregation)
}
}

Database.extend(MemoryDatabase, {
async drop(name) {
if (name) {
Expand Down Expand Up @@ -149,6 +189,12 @@ Database.extend(MemoryDatabase, {
}
},

async aggregate(name, fields, query) {
const expr = Query.resolve(name, query)
const table = this.$table(name).filter(row => executeQuery(expr, row))
return Object.fromEntries(Object.entries(fields).map(([key, expr]) => [key, executeAggregation(expr, table)]))
},

async getUser(type, id, fields) {
if (Array.isArray(id)) {
return this.get('user', { [type]: id }, fields) as any
Expand Down
57 changes: 41 additions & 16 deletions packages/plugin-mongo/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import MongoDatabase, { Config } from './database'
import { User, Tables, Database, Context, Channel, Random, pick, omit, TableType, Query } from 'koishi-core'
import { User, Tables, Database, Context, Channel, Random, pick, omit, TableType, Query, Eval } from 'koishi-core'
import { QuerySelector } from 'mongodb'

export * from './database'
Expand Down Expand Up @@ -102,22 +102,22 @@ function transformFieldQuery(query: Query.FieldQuery, key: string) {
return result
}

function createFilter<T extends TableType>(name: T, _query: Query<T>) {
function transformQuery(query: Query.Expr) {
const filter = {}
for (const key in query) {
const value = query[key]
if (key === '$and' || key === '$or') {
filter[key] = value.map(transformQuery)
} else if (key === '$not') {
filter[key] = transformQuery(value)
} else {
filter[key] = transformFieldQuery(value, key)
}
function transformQuery(query: Query.Expr) {
const filter = {}
for (const key in query) {
const value = query[key]
if (key === '$and' || key === '$or') {
filter[key] = value.map(transformQuery)
} else if (key === '$not') {
filter[key] = transformQuery(value)
} else {
filter[key] = transformFieldQuery(value, key)
}
return filter
}
return filter
}

function createFilter<T extends TableType>(name: T, _query: Query<T>) {
const filter = transformQuery(Query.resolve(name, _query))
const { primary } = Tables.config[name]
if (filter[primary]) {
Expand All @@ -132,6 +132,22 @@ function createFilter<T extends TableType>(name: T, _query: Query<T>) {
return filter
}

function transformEval(expr: Eval.Numeric | Eval.Aggregation) {
if (typeof expr === 'string') {
return '$' + expr
} else if (typeof expr === 'number') {
return expr
}

return Object.fromEntries(Object.entries(expr).map(([key, value]) => {
if (Array.isArray(value)) {
return [key, value.map(transformEval)]
} else {
return [key, transformEval(value)]
}
}))
}

function getFallbackType({ fields, primary }: Tables.Config) {
const { type } = fields[primary]
return Tables.Field.string.includes(type) ? 'random' : 'incremental'
Expand All @@ -149,7 +165,6 @@ Database.extend(MongoDatabase, {

async get(name, query, modifier) {
const filter = createFilter(name, query)
if (!filter) return []
let cursor = this.db.collection(name).find(filter)
const { fields, limit, offset = 0 } = Query.resolveModifier(modifier)
if (fields) cursor = cursor.project(Object.fromEntries(fields.map(key => [key, 1])))
Expand All @@ -167,7 +182,6 @@ Database.extend(MongoDatabase, {

async remove(name, query) {
const filter = createFilter(name, query)
if (!filter) return
await this.db.collection(name).deleteMany(filter)
},

Expand Down Expand Up @@ -198,6 +212,17 @@ Database.extend(MongoDatabase, {
await bulk.execute()
},

async aggregate(name, fields, query) {
const $match = createFilter(name, query)
const [data] = await this.db.collection(name).aggregate([{ $match }, {
$group: {
_id: 1,
...Object.fromEntries(Object.entries(fields).map(([key, value]) => [key, transformEval(value)])),
},
}]).toArray()
return data
},

async getUser(type, id, modifier) {
const { fields } = Query.resolveModifier(modifier)
const applyDefault = (user: User) => ({
Expand Down
38 changes: 37 additions & 1 deletion packages/plugin-mysql/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import MysqlDatabase, { Config } from './database'
import { User, Channel, Database, Context, Query, Evaluation } from 'koishi-core'
import { User, Channel, Database, Context, Query, Eval } from 'koishi-core'
import { difference } from 'koishi-utils'
import { OkPacket, escapeId, escape } from 'mysql'
import * as Koishi from 'koishi-core'
Expand Down Expand Up @@ -117,6 +117,32 @@ function parseQuery(query: Query.Expr) {
return conditions.join(' && ')
}

function parseNumeric(expr: Eval.Aggregation | Eval.Numeric) {
if (typeof expr === 'string') {
return escapeId(expr)
} else if (typeof expr === 'number') {
return escape(expr)
} else if ('$sum' in expr) {
return `ifnull(sum(${parseNumeric(expr.$sum)}), 0)`
} else if ('$avg' in expr) {
return `avg(${parseNumeric(expr.$avg)})`
} else if ('$min' in expr) {
return `min(${parseNumeric(expr.$min)})`
} else if ('$max' in expr) {
return `max(${parseNumeric(expr.$max)})`
} else if ('$count' in expr) {
return `count(${parseNumeric(expr.$count)})`
} else if ('$add' in expr) {
return expr.$add.map(parseNumeric).join(' + ')
} else if ('$multiply' in expr) {
return expr.$multiply.map(parseNumeric).join(' * ')
} else if ('$subtract' in expr) {
return expr.$subtract.map(parseNumeric).join(' - ')
} else if ('$divide' in expr) {
return expr.$divide.map(parseNumeric).join(' / ')
}
}

Database.extend(MysqlDatabase, {
async drop(name) {
if (name) {
Expand Down Expand Up @@ -171,6 +197,16 @@ Database.extend(MysqlDatabase, {
)
},

async aggregate(name, fields, query) {
const keys = Object.keys(fields)
if (!keys.length) return {}

const filter = parseQuery(Query.resolve(name, query))
const exprs = keys.map(key => `${parseNumeric(fields[key])} AS ${escapeId(key)}`).join(', ')
const [data] = await this.query(`SELECT ${exprs} FROM ${name} WHERE ${filter}`)
return data
},

async getUser(type, id, modifier) {
const { fields } = Query.resolveModifier(modifier)
if (fields && !fields.length) {
Expand Down

0 comments on commit ac2fac6

Please sign in to comment.