Skip to content

Commit

Permalink
feat: Allow user-defined characteristics on rate limit options (#203)
Browse files Browse the repository at this point in the history
Closes #202 
Closes arcjet/arcjet#597

This adds TypeScript support for custom characteristics. While it was supported via the protocol, there was no indication to the user that they needed to add the field to every request.

By changing the types of our primitives, we can make the prop required on the request object (while filtering our well-known characteristics).

Draft because it needs some tests.
  • Loading branch information
blaine-arcjet authored Feb 7, 2024
1 parent a945b2c commit dc5b001
Show file tree
Hide file tree
Showing 3 changed files with 262 additions and 40 deletions.
145 changes: 110 additions & 35 deletions arcjet/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ function errorMessage(err: unknown): string {
// https://github.com/sindresorhus/type-fest/blob/964466c9d59c711da57a5297ad954c13132a0001/source/simplify.d.ts
// UnionToIntersection:
// https://github.com/sindresorhus/type-fest/blob/017bf38ebb52df37c297324d97bcc693ec22e920/source/union-to-intersection.d.ts
// IsNever:
// https://github.com/sindresorhus/type-fest/blob/e02f228f6391bb2b26c32a55dfe1e3aa2386d515/source/primitive.d.ts
// LiteralCheck & IsStringLiteral:
// https://github.com/sindresorhus/type-fest/blob/e02f228f6391bb2b26c32a55dfe1e3aa2386d515/source/is-literal.d.ts
//
// Licensed: MIT License Copyright (c) Sindre Sorhus <sindresorhus@gmail.com>
// (https://sindresorhus.com)
Expand Down Expand Up @@ -149,6 +153,25 @@ type UnionToIntersection<Union> =
? // The `& Union` is to allow indexing by the resulting type
Intersection & Union
: never;
type IsNever<T> = [T] extends [never] ? true : false;
type LiteralCheck<
T,
LiteralType extends
| null
| undefined
| string
| number
| boolean
| symbol
| bigint,
> = IsNever<T> extends false // Must be wider than `never`
? [T] extends [LiteralType] // Must be narrower than `LiteralType`
? [LiteralType] extends [T] // Cannot be wider than `LiteralType`
? false
: true
: false
: false;
type IsStringLiteral<T> = LiteralCheck<T, string>;

export interface RemoteClient {
decide(
Expand Down Expand Up @@ -417,30 +440,31 @@ function runtime(): Runtime {
}
}

type TokenBucketRateLimitOptions = {
type TokenBucketRateLimitOptions<Characteristics extends readonly string[]> = {
mode?: ArcjetMode;
match?: string;
characteristics?: string[];
characteristics?: Characteristics;
refillRate: number;
interval: string | number;
capacity: number;
};

type FixedWindowRateLimitOptions = {
type FixedWindowRateLimitOptions<Characteristics extends readonly string[]> = {
mode?: ArcjetMode;
match?: string;
characteristics?: string[];
characteristics?: Characteristics;
window: string | number;
max: number;
};

type SlidingWindowRateLimitOptions = {
mode?: ArcjetMode;
match?: string;
characteristics?: string[];
interval: string | number;
max: number;
};
type SlidingWindowRateLimitOptions<Characteristics extends readonly string[]> =
{
mode?: ArcjetMode;
match?: string;
characteristics?: Characteristics;
interval: string | number;
max: number;
};

/**
* Bot detection is disabled by default. The `bots` configuration block allows
Expand Down Expand Up @@ -550,6 +574,25 @@ type PlainObject = { [key: string]: unknown };
export type Primitive<Props extends PlainObject = {}> = ArcjetRule<Props>[];
export type Product<Props extends PlainObject = {}> = ArcjetRule<Props>[];

// User-defined characteristics alter the required props of an ArcjetRequest
// Note: If a user doesn't provide the object literal to our primitives
// directly, we fallback to no required props. They can opt-in by adding the
// `as const` suffix to the characteristics array.
type PropsForCharacteristic<T> = IsStringLiteral<T> extends true
? T extends
| "ip.src"
| "http.host"
| "http.method"
| "http.request.uri.path"
| `http.request.headers["${string}"]`
| `http.request.cookie["${string}"]`
| `http.request.uri.args["${string}"]`
? {}
: T extends string
? Record<T, string | number | boolean>
: never
: {};
// Rules can specify they require specific props on an ArcjetRequest
type PropsForRule<R> = R extends ArcjetRule<infer Props> ? Props : {};
// We theoretically support an arbitrary amount of rule flattening,
// but one level seems to be easiest; however, this puts a constraint of
Expand Down Expand Up @@ -590,10 +633,18 @@ function isLocalRule<Props extends PlainObject>(
);
}

export function tokenBucket(
options?: TokenBucketRateLimitOptions,
...additionalOptions: TokenBucketRateLimitOptions[]
): Primitive<{ requested: number }> {
export function tokenBucket<
const Characteristics extends readonly string[] = [],
>(
options?: TokenBucketRateLimitOptions<Characteristics>,
...additionalOptions: TokenBucketRateLimitOptions<Characteristics>[]
): Primitive<
Simplify<
UnionToIntersection<
{ requested: number } | PropsForCharacteristic<Characteristics[number]>
>
>
> {
const rules: ArcjetTokenBucketRateLimitRule<{ requested: number }>[] = [];

if (typeof options === "undefined") {
Expand All @@ -603,7 +654,9 @@ export function tokenBucket(
for (const opt of [options, ...additionalOptions]) {
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
const match = opt.match;
const characteristics = opt.characteristics;
const characteristics = Array.isArray(opt.characteristics)
? opt.characteristics
: undefined;

const refillRate = opt.refillRate;
const interval = duration.parse(opt.interval);
Expand All @@ -625,10 +678,14 @@ export function tokenBucket(
return rules;
}

export function fixedWindow(
options?: FixedWindowRateLimitOptions,
...additionalOptions: FixedWindowRateLimitOptions[]
): Primitive {
export function fixedWindow<
const Characteristics extends readonly string[] = [],
>(
options?: FixedWindowRateLimitOptions<Characteristics>,
...additionalOptions: FixedWindowRateLimitOptions<Characteristics>[]
): Primitive<
Simplify<UnionToIntersection<PropsForCharacteristic<Characteristics[number]>>>
> {
const rules: ArcjetFixedWindowRateLimitRule<{}>[] = [];

if (typeof options === "undefined") {
Expand All @@ -638,7 +695,9 @@ export function fixedWindow(
for (const opt of [options, ...additionalOptions]) {
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
const match = opt.match;
const characteristics = opt.characteristics;
const characteristics = Array.isArray(opt.characteristics)
? opt.characteristics
: undefined;

const max = opt.max;
const window = duration.parse(opt.window);
Expand All @@ -660,19 +719,25 @@ export function fixedWindow(

// This is currently kept for backwards compatibility but should be removed in
// favor of the fixedWindow primitive.
export function rateLimit(
options?: FixedWindowRateLimitOptions,
...additionalOptions: FixedWindowRateLimitOptions[]
): Primitive {
export function rateLimit<const Characteristics extends readonly string[] = []>(
options?: FixedWindowRateLimitOptions<Characteristics>,
...additionalOptions: FixedWindowRateLimitOptions<Characteristics>[]
): Primitive<
Simplify<UnionToIntersection<PropsForCharacteristic<Characteristics[number]>>>
> {
// TODO(#195): We should also have a local rate limit using an in-memory data
// structure if the environment supports it
return fixedWindow(options, ...additionalOptions);
}

export function slidingWindow(
options?: SlidingWindowRateLimitOptions,
...additionalOptions: SlidingWindowRateLimitOptions[]
): Primitive {
export function slidingWindow<
const Characteristics extends readonly string[] = [],
>(
options?: SlidingWindowRateLimitOptions<Characteristics>,
...additionalOptions: SlidingWindowRateLimitOptions<Characteristics>[]
): Primitive<
Simplify<UnionToIntersection<PropsForCharacteristic<Characteristics[number]>>>
> {
const rules: ArcjetSlidingWindowRateLimitRule<{}>[] = [];

if (typeof options === "undefined") {
Expand All @@ -682,7 +747,9 @@ export function slidingWindow(
for (const opt of [options, ...additionalOptions]) {
const mode = opt.mode === "LIVE" ? "LIVE" : "DRY_RUN";
const match = opt.match;
const characteristics = opt.characteristics;
const characteristics = Array.isArray(opt.characteristics)
? opt.characteristics
: undefined;

const max = opt.max;
const interval = duration.parse(opt.interval);
Expand Down Expand Up @@ -867,15 +934,23 @@ export function detectBot(
return rules;
}

export type ProtectSignupOptions = {
rateLimit?: SlidingWindowRateLimitOptions | SlidingWindowRateLimitOptions[];
export type ProtectSignupOptions<Characteristics extends string[]> = {
rateLimit?:
| SlidingWindowRateLimitOptions<Characteristics>
| SlidingWindowRateLimitOptions<Characteristics>[];
bots?: BotOptions | BotOptions[];
email?: EmailOptions | EmailOptions[];
};

export function protectSignup(
options?: ProtectSignupOptions,
): Product<{ email: string }> {
export function protectSignup<const Characteristics extends string[] = []>(
options?: ProtectSignupOptions<Characteristics>,
): Product<
Simplify<
UnionToIntersection<
{ email: string } | PropsForCharacteristic<Characteristics[number]>
>
>
> {
let rateLimitRules: Primitive<{}> = [];
if (Array.isArray(options?.rateLimit)) {
rateLimitRules = slidingWindow(...options.rateLimit);
Expand Down
29 changes: 24 additions & 5 deletions arcjet/test/index.edge.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,28 @@ describe("Arcjet: Env = Edge runtime", () => {
rules: [
// Test rules
foobarbaz(),
tokenBucket({
refillRate: 1,
interval: 1,
capacity: 1,
}),
tokenBucket(
{
characteristics: [
"ip.src",
"http.host",
"http.method",
"http.request.uri.path",
`http.request.headers["abc"]`,
`http.request.cookie["xyz"]`,
`http.request.uri.args["foobar"]`,
],
refillRate: 1,
interval: 1,
capacity: 1,
},
{
characteristics: ["userId"],
refillRate: 1,
interval: 1,
capacity: 1,
},
),
rateLimit({
max: 1,
window: "60s",
Expand All @@ -61,6 +78,8 @@ describe("Arcjet: Env = Edge runtime", () => {
path: "",
headers: new Headers(),
extra: {},
userId: "user123",
foobar: 123,
});

expect(decision.isErrored()).toBe(false);
Expand Down
Loading

0 comments on commit dc5b001

Please sign in to comment.