From afc647abbdcdaac12fd32cb3897835928e98f515 Mon Sep 17 00:00:00 2001 From: Shigma <1700011071@pku.edu.cn> Date: Tue, 25 Aug 2020 17:46:03 +0800 Subject: [PATCH] teach: fix redirect breaks throttle & preventLoop --- .../plugin-teach/src/plugins/preventLoop.ts | 1 + packages/plugin-teach/src/plugins/throttle.ts | 6 ++- packages/plugin-teach/src/receiver.ts | 18 +++++--- packages/plugin-teach/tests/basic.spec.ts | 45 +++++++++++++++++-- 4 files changed, 58 insertions(+), 12 deletions(-) diff --git a/packages/plugin-teach/src/plugins/preventLoop.ts b/packages/plugin-teach/src/plugins/preventLoop.ts index 948dac50f1..e2bd410736 100644 --- a/packages/plugin-teach/src/plugins/preventLoop.ts +++ b/packages/plugin-teach/src/plugins/preventLoop.ts @@ -35,6 +35,7 @@ export default function apply(ctx: Context, config: Dialogue.Config) { }) ctx.on('dialogue/receive', (state) => { + if (state.session._redirected) return const timestamp = Date.now() for (const { participants, length, debounce } of preventLoopConfig) { if (state.initiators.length < length) break diff --git a/packages/plugin-teach/src/plugins/throttle.ts b/packages/plugin-teach/src/plugins/throttle.ts index a07b83fb37..2514c35c11 100644 --- a/packages/plugin-teach/src/plugins/throttle.ts +++ b/packages/plugin-teach/src/plugins/throttle.ts @@ -35,13 +35,15 @@ export default function apply(ctx: Context, config: Dialogue.Config) { state.counters = { ...counters } }) - ctx.on('dialogue/receive', ({ counters }) => { + ctx.on('dialogue/receive', ({ counters, session }) => { + if (session._redirected) return for (const interval in counters) { if (counters[interval] <= 0) return true } }) - ctx.on('dialogue/before-send', ({ counters }) => { + ctx.on('dialogue/before-send', ({ counters, session }) => { + if (session._redirected) return for (const { interval } of throttleConfig) { counters[interval]-- setTimeout(() => counters[interval]++, interval) diff --git a/packages/plugin-teach/src/receiver.ts b/packages/plugin-teach/src/receiver.ts index 6a58702086..c58b37260d 100644 --- a/packages/plugin-teach/src/receiver.ts +++ b/packages/plugin-teach/src/receiver.ts @@ -2,6 +2,12 @@ import { Context, User, Session, NextFunction, Command } from 'koishi-core' import { CQCode, simplify, noop, escapeRegExp } from 'koishi-utils' import { Dialogue, DialogueTest } from './utils' +declare module 'koishi-core/dist/app' { + interface App { + _dialogueStates: Record + } +} + declare module 'koishi-core/dist/context' { interface EventMap { 'dialogue/state'(state: SessionState): void @@ -59,8 +65,6 @@ export interface SessionState { isSearch?: boolean } -const states: Record = {} - export function escapeAnswer(message: string) { return message.replace(/%/g, '@@__PLACEHOLDER__@@') } @@ -70,11 +74,11 @@ export function unescapeAnswer(message: string) { } Context.prototype.getSessionState = function (session) { - const { groupId, anonymous, userId } = session - if (!states[groupId]) { - this.emit('dialogue/state', states[groupId] = { groupId } as SessionState) + const { groupId, anonymous, userId, $app } = session + if (!$app._dialogueStates[groupId]) { + this.emit('dialogue/state', $app._dialogueStates[groupId] = { groupId } as SessionState) } - const state = Object.create(states[groupId]) + const state = Object.create($app._dialogueStates[groupId]) state.session = session state.userId = anonymous ? -anonymous.id : userId return state @@ -249,6 +253,8 @@ export default function (ctx: Context, config: Dialogue.Config) { const nicknames = Array.isArray(nickname) ? nickname : nickname ? [nickname] : [] const nicknameRE = new RegExp(`^((${nicknames.map(escapeRegExp).join('|')})[,,]?\\s*)+`) + ctx.app._dialogueStates = {} + config._stripQuestion = (source) => { source = prepareSource(source) const original = source diff --git a/packages/plugin-teach/tests/basic.spec.ts b/packages/plugin-teach/tests/basic.spec.ts index 0acd833fbf..828dc1c1bc 100644 --- a/packages/plugin-teach/tests/basic.spec.ts +++ b/packages/plugin-teach/tests/basic.spec.ts @@ -102,16 +102,18 @@ describe('Plugin Teach', () => { app.plugin(utils) - before(async () => { + async function start() { await app.start() await app.database.getUser(u2id, 2) await app.database.getUser(u3id, 3) await app.database.getUser(u4id, 4) await app.database.getGroup(g1id, app.selfId) await app.database.getGroup(g2id, app.selfId) - }) + } + + before(start) - return { app, u2, u3, u4, u2g1, u2g2, u3g1, u3g2, u4g1, u4g2 } + return { app, u2, u3, u4, u2g1, u2g2, u3g1, u3g2, u4g1, u4g2, start } } const DETAIL_HEAD = '编号为 1 的问答信息:\n问题:foo\n回答:bar\n' @@ -237,10 +239,45 @@ describe('Plugin Teach', () => { await u3g1.shouldHaveReply('#1 ~ %s:%{test}', '问答 1 已成功修改。') await u2g1.shouldHaveReply('foo', 'nick2:200') await u3g1.shouldHaveReply('#1 -s', '问答 1 已成功修改。') - await u3g1.shouldHaveReply('## foo', SEARCH_HEAD + '1. [代行] %s:%{test}') await u3g1.shouldHaveReply('#1 -w [CQ:at,qq=300]', '问答 1 已成功修改。') await u3g1.shouldHaveReply('#1', DETAIL_HEAD + '来源:user3 (300)\n回答中的指令由教学者代行。') + await u3g1.shouldHaveReply('## foo', SEARCH_HEAD + '1. [代行] %s:%{test}') await u2g1.shouldHaveReply('foo', 'nick2:300') }) }) + + describe('restriction', () => { + // make coverage happy + new App().plugin(teach, { throttle: [] }) + new App().plugin(teach, { preventLoop: [] }) + new App().plugin(teach, { preventLoop: 10 }) + + it('throttle', async () => { + const { u2g1, u3g1, u4g1, u4g2, start } = createEnvironment({ throttle: { interval: 1000, responses: 2 } }) + + await start() + await u3g1.shouldHaveReply('# baz bar', '问答已添加,编号为 1。') + await u3g1.shouldHaveReply('# foo => baz', '问答已添加,编号为 2。') + await u2g1.shouldHaveReply('foo', 'bar') + await u3g1.shouldHaveReply('foo', 'bar') + await u4g1.shouldHaveNoReply('foo') + await u4g2.shouldHaveReply('foo', 'bar') + }) + + it('preventLoop', async () => { + const { u2g1, u3g1, u4g1, start } = createEnvironment({ preventLoop: { length: 5, participants: 2 } }) + + await start() + await u3g1.shouldHaveReply('# baz bar', '问答已添加,编号为 1。') + await u3g1.shouldHaveReply('# foo => baz', '问答已添加,编号为 2。') + await u2g1.shouldHaveReply('foo', 'bar') + await u2g1.shouldHaveReply('foo', 'bar') + await u3g1.shouldHaveReply('foo', 'bar') + await u3g1.shouldHaveReply('foo', 'bar') + await u2g1.shouldHaveReply('foo', 'bar') + await u2g1.shouldHaveNoReply('foo') + await u3g1.shouldHaveNoReply('foo') + await u4g1.shouldHaveReply('foo', 'bar') + }) + }) })