Skip to content

Commit

Permalink
feat: add get method to store (#195)
Browse files Browse the repository at this point in the history
  • Loading branch information
gamemaker1 authored Oct 4, 2023
1 parent 102bbc2 commit f1880d9
Show file tree
Hide file tree
Showing 4 changed files with 171 additions and 83 deletions.
168 changes: 101 additions & 67 deletions source/lib.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,43 @@
import type {
Store,
IncrementResponse,
ClientRateLimitInfo,
Options as RateLimitConfiguration,
} from 'express-rate-limit'
import { type Options, type SendCommandFn } from './types.js'
import scripts from './scripts.js'
import type { Options, SendCommandFn, RedisReply } from './types.js'

/**
* Converts a string/number to a number.
*
* @param input {string | number | undefined} - The input to convert to a number.
*
* @return {number} - The parsed integer.
* @throws {Error} - Thrown if the string does not contain a valid number.
*/
const toInt = (input: string | number | boolean | undefined): number => {
if (typeof input === 'number') return input
return Number.parseInt((input ?? '').toString(), 10)
}

/**
* Parses the response from the script.
*
* Note that the responses returned by the `get` and `increment` scripts are
* the same, so this function can be used with both.
*/
const parseScriptResponse = (results: RedisReply): ClientRateLimitInfo => {
if (!Array.isArray(results))
throw new TypeError('Expected result to be array of values')
if (results.length !== 2)
throw new Error(`Expected 2 replies, got ${results.length}`)

const totalHits = toInt(results[0])
const timeToExpire = toInt(results[1])

const resetTime = new Date(Date.now() + timeToExpire)
return { totalHits, resetTime }
}

/**
* A `Store` for the `express-rate-limit` package that stores hit counts in
Expand All @@ -30,9 +64,11 @@ class RedisStore implements Store {
resetExpiryOnChange: boolean

/**
* Stores the loaded SHA1 of the LUA script for executing the increment operations.
* Stores the loaded SHA1s of the LUA scripts used for executing the increment
* and get key operations.
*/
loadedScriptSha1: Promise<string>
incrementScriptSha: Promise<string>
getScriptSha: Promise<string>

/**
* The number of milliseconds to remember that user's requests.
Expand All @@ -51,32 +87,18 @@ class RedisStore implements Store {

// So that the script loading can occur non-blocking, this will send
// the script to be loaded, and will capture the value within the
// promise return. This way, if increments start being called before
// promise return. This way, if increment/get start being called before
// the script has finished loading, it will wait until it is loaded
// before it continues.
this.loadedScriptSha1 = this.loadScript()
this.incrementScriptSha = this.loadIncrementScript()
this.getScriptSha = this.loadGetScript()
}

async loadScript(): Promise<string> {
const result = await this.sendCommand(
'SCRIPT',
'LOAD',
`
local totalHits = redis.call("INCR", KEYS[1])
local timeToExpire = redis.call("PTTL", KEYS[1])
if timeToExpire <= 0 or ARGV[1] == "1"
then
redis.call("PEXPIRE", KEYS[1], tonumber(ARGV[2]))
timeToExpire = tonumber(ARGV[2])
end
return { totalHits, timeToExpire }
`
// Ensure that code changes that affect whitespace do not affect
// the script contents.
.replaceAll(/^\s+/gm, '')
.trim(),
)
/**
* Loads the script used to increment a client's hit count.
*/
async loadIncrementScript(): Promise<string> {
const result = await this.sendCommand('SCRIPT', 'LOAD', scripts.increment)

if (typeof result !== 'string') {
throw new TypeError('unexpected reply from redis client')
Expand All @@ -86,30 +108,26 @@ class RedisStore implements Store {
}

/**
* Method to prefix the keys with the given text.
*
* @param key {string} - The key.
*
* @returns {string} - The text + the key.
* Loads the script used to fetch a client's hit count and expiry time.
*/
prefixKey(key: string): string {
return `${this.prefix}${key}`
async loadGetScript(): Promise<string> {
const result = await this.sendCommand('SCRIPT', 'LOAD', scripts.get)

if (typeof result !== 'string') {
throw new TypeError('unexpected reply from redis client')
}

return result
}

/**
* Method that actually initializes the store.
*
* @param options {RateLimitConfiguration} - The options used to setup the middleware.
* Runs the increment command, and retries it if the script is not loaded.
*/
init(options: RateLimitConfiguration) {
this.windowMs = options.windowMs
}

async runCommandWithRetry(key: string) {
async retryableIncrement(key: string): Promise<RedisReply> {
const evalCommand = async () =>
this.sendCommand(
'EVALSHA',
await this.loadedScriptSha1,
await this.incrementScriptSha,
'1',
this.prefixKey(key),
this.resetExpiryOnChange ? '1' : '0',
Expand All @@ -121,44 +139,59 @@ class RedisStore implements Store {
return result
} catch {
// TODO: distinguish different error types
this.loadedScriptSha1 = this.loadScript()
this.incrementScriptSha = this.loadIncrementScript()
return evalCommand()
}
}

/**
* Method to increment a client's hit counter.
* Method to prefix the keys with the given text.
*
* @param key {string} - The identifier for a client
* @param key {string} - The key.
*
* @returns {IncrementResponse} - The number of hits and reset time for that client
* @returns {string} - The text + the key.
*/
async increment(key: string): Promise<IncrementResponse> {
const results = await this.runCommandWithRetry(key)

if (!Array.isArray(results)) {
throw new TypeError('Expected result to be array of values')
}
prefixKey(key: string): string {
return `${this.prefix}${key}`
}

if (results.length !== 2) {
throw new Error(`Expected 2 replies, got ${results.length}`)
}
/**
* Method that actually initializes the store.
*
* @param options {RateLimitConfiguration} - The options used to setup the middleware.
*/
init(options: RateLimitConfiguration) {
this.windowMs = options.windowMs
}

const totalHits = results[0]
if (typeof totalHits !== 'number') {
throw new TypeError('Expected value to be a number')
}
/**
* Method to fetch a client's hit count and reset time.
*
* @param key {string} - The identifier for a client.
*
* @returns {ClientRateLimitInfo | undefined} - The number of hits and reset time for that client.
*/
async get(key: string): Promise<ClientRateLimitInfo | undefined> {
const results = await this.sendCommand(
'EVALSHA',
await this.getScriptSha,
'1',
this.prefixKey(key),
)

const timeToExpire = results[1]
if (typeof timeToExpire !== 'number') {
throw new TypeError('Expected value to be a number')
}
return parseScriptResponse(results)
}

const resetTime = new Date(Date.now() + timeToExpire)
return {
totalHits,
resetTime,
}
/**
* Method to increment a client's hit counter.
*
* @param key {string} - The identifier for a client
*
* @returns {IncrementResponse} - The number of hits and reset time for that client
*/
async increment(key: string): Promise<IncrementResponse> {
const results = await this.retryableIncrement(key)
return parseScriptResponse(results)
}

/**
Expand All @@ -180,4 +213,5 @@ class RedisStore implements Store {
}
}

// Export it to the world!
export default RedisStore
35 changes: 35 additions & 0 deletions source/scripts.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// /source/scripts.ts
// The lua scripts for the increment and get operations.

/**
* The lua scripts, used to make consecutive queries on the same key and avoid
* race conditions by doing all the work on the redis server.
*/
const scripts = {
increment: `
local totalHits = redis.call("INCR", KEYS[1])
local timeToExpire = redis.call("PTTL", KEYS[1])
if timeToExpire <= 0 or ARGV[1] == "1"
then
redis.call("PEXPIRE", KEYS[1], tonumber(ARGV[2]))
timeToExpire = tonumber(ARGV[2])
end
return { totalHits, timeToExpire }
`
// Ensure that code changes that affect whitespace do not affect
// the script contents.
.replaceAll(/^\s+/gm, '')
.trim(),
get: `
local totalHits = redis.call("GET", KEYS[1])
local timeToExpire = redis.call("PTTL", KEYS[1])
return { totalHits, timeToExpire }
`
.replaceAll(/^\s+/gm, '')
.trim(),
}

// Export them so we can use them in the `lib.ts` file.
export default scripts
7 changes: 3 additions & 4 deletions source/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,14 @@
/**
* The type of data Redis might return to us.
*/
export type RedisReply = number | string
type Data = boolean | number | string
export type RedisReply = Data | Data[]

/**
* The library sends Redis raw commands, so all we need to know are the
* 'raw-command-sending' functions for each redis client.
*/
export type SendCommandFn = (
...args: string[]
) => Promise<RedisReply | RedisReply[]>
export type SendCommandFn = (...args: string[]) => Promise<RedisReply>

/**
* The configuration options for the store.
Expand Down
44 changes: 32 additions & 12 deletions test/store-test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@
// The tests for the store.

import { createHash } from 'node:crypto'
import { jest } from '@jest/globals'
import { expect, jest } from '@jest/globals'
import { type Options } from 'express-rate-limit'
import MockRedisClient from 'ioredis-mock'
import RedisStore, { type RedisReply } from '../source/index.js'

// The SHA of the script to evaluate
let scriptSha: string | undefined
// The mock redis client to use.
const client = new MockRedisClient()

Expand All @@ -18,29 +16,33 @@ const client = new MockRedisClient()
*
* @param {string[]} ...args - The raw command to send.
*
* @return {RedisReply | RedisReply[]} The reply returned by Redis.
* @return {RedisReply} The reply returned by Redis.
*/
const sendCommand = async (
...args: string[]
): Promise<RedisReply | RedisReply[]> => {
const sendCommand = async (...args: string[]): Promise<RedisReply> => {
// `SCRIPT LOAD`, called when the store is initialized. This loads the lua script
// for incrementing a client's hit counter.
if (args[0] === 'SCRIPT') {
// `ioredis-mock` doesn't have a `SCRIPT LOAD` function, so we have to compute
// the SHA manually and `EVAL` the script to get it saved.
const shasum = createHash('sha1')
shasum.update(args[2])
scriptSha = shasum.digest('hex')
await client.eval(args[2], 1, '__test', '0', '100')
const sha = shasum.digest('hex')

const testArgs = args[2].includes('INCR')
? ['__test_incr', '0', '10']
: ['__test_get']
await client.eval(args[2], 1, ...testArgs)

// Return the SHA to the store.
return scriptSha
return sha
}

// `EVALSHA` executes the script that was loaded already with the given arguments
if (args[0] === 'EVALSHA')
if (args[0] === 'EVALSHA') {
// @ts-expect-error Wrong types :/
return client.evalsha(scriptSha!, ...args.slice(2)) as number[]
return client.evalsha(...args.slice(1)) as number[]
}

// `DECR` decrements the count for a client.
if (args[0] === 'DECR') return client.decr(args[1])
// `DEL` resets the count for a client by deleting the key.
Expand Down Expand Up @@ -128,6 +130,7 @@ describe('redis store test', () => {
const key = 'test-store'

await store.increment(key) // => 1
await store.increment(key) // => 2
await store.resetKey(key) // => undefined

const { totalHits } = await store.increment(key) // => 1
Expand All @@ -139,6 +142,23 @@ describe('redis store test', () => {
expect(Number(await client.pttl('rl:test-store'))).toEqual(10)
})

it('fetches the count for a key in the store when `getKey` is called', async () => {
const store = new RedisStore({ sendCommand })
store.init({ windowMs: 10 } as Options)

const key = 'test-store'

await store.increment(key) // => 1
await store.increment(key) // => 2
const info = await store.get(key)

// Ensure the hit count is 1, and that `resetTime` is a date.
expect(info).toMatchObject({
totalHits: 2,
resetTime: expect.any(Date),
})
})

it('resets expiry time on change if `resetExpiryOnChange` is set to `true`', async () => {
const store = new RedisStore({ sendCommand, resetExpiryOnChange: true })
store.init({ windowMs: 60 } as Options)
Expand Down

0 comments on commit f1880d9

Please sign in to comment.