Skip to content

Commit

Permalink
Bypass spaces tools rate limit (#1167)
Browse files Browse the repository at this point in the history
* Bypass spaces tools rate limit

* DRY

* optimize

* use `context` for userId & ip

* lint

* use `userName`

* fix merge conflict

* rename `userName` -> `username`
  • Loading branch information
Mishig authored May 28, 2024
1 parent 723982a commit 30b4335
Show file tree
Hide file tree
Showing 11 changed files with 73 additions and 32 deletions.
1 change: 1 addition & 0 deletions chart/env/prod.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,7 @@ externalSecrets:
ADMIN_API_SECRET: "hub-prod-chat-ui-admin-api-secret"
USAGE_LIMITS: "hub-prod-chat-ui-usage-limits"
MESSAGES_BEFORE_LOGIN: "hub-prod-chat-ui-messages-before-login"
IP_TOKEN_SECRET: "hub-prod-chat-ui-ip-token-secret"

autoscaling:
enabled: true
Expand Down
15 changes: 12 additions & 3 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
"highlight.js": "^11.7.0",
"image-size": "^1.0.2",
"ip-address": "^9.0.5",
"jose": "^5.3.0",
"jsdom": "^22.0.0",
"json5": "^2.2.3",
"marked": "^12.0.1",
Expand Down
14 changes: 5 additions & 9 deletions src/lib/server/textGeneration/tools.ts
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ export function pickTools(
}

async function* runTool(
{ conv, messages, preprompt, assistant }: BackendToolContext,
ctx: BackendToolContext,
tools: BackendTool[],
call: ToolCall
): AsyncGenerator<MessageUpdate, ToolResult | undefined, undefined> {
Expand All @@ -74,12 +74,7 @@ async function* runTool(
};
try {
try {
const toolResult = yield* tool.call(call.parameters, {
conv,
messages,
preprompt,
assistant,
});
const toolResult = yield* tool.call(call.parameters, ctx);
if (toolResult.status === ToolResultStatus.Error) {
yield {
type: MessageUpdateType.Tool,
Expand Down Expand Up @@ -123,10 +118,11 @@ async function* runTool(
}

export async function* runTools(
{ endpoint, conv, messages, assistant }: TextGenerationContext,
ctx: TextGenerationContext,
tools: BackendTool[],
preprompt?: string
): AsyncGenerator<MessageUpdate, ToolResult[], undefined> {
const { endpoint, conv, messages, assistant, ip, username } = ctx;
const calls: ToolCall[] = [];

const messagesWithFilesPrompt = messages.map((message, idx) => {
Expand Down Expand Up @@ -181,7 +177,7 @@ export async function* runTools(
Date.now() - pickToolStartTime
);

const toolContext: BackendToolContext = { conv, messages, preprompt, assistant };
const toolContext: BackendToolContext = { conv, messages, preprompt, assistant, ip, username };
const toolResults: (ToolResult | undefined)[] = yield* mergeAsyncGenerators(
calls.map((call) => runTool(toolContext, tools, call))
);
Expand Down
2 changes: 2 additions & 0 deletions src/lib/server/textGeneration/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,6 @@ export interface TextGenerationContext {
webSearch: boolean;
toolsPreference: Record<string, boolean>;
promptedAt: Date;
ip: string;
username?: string;
}
9 changes: 6 additions & 3 deletions src/lib/server/tools/documentParser.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import type { BackendTool } from ".";
import { ToolResultStatus } from "$lib/types/Tool";
import { callSpace } from "./utils";
import { callSpace, getIpToken } from "./utils";
import { downloadFile } from "$lib/server/files/downloadFile";

type PdfParserInput = [Blob /* pdf */, string /* filename */];
Expand All @@ -23,7 +23,7 @@ const documentParser: BackendTool = {
required: true,
},
},
async *call({ fileMessageIndex, fileIndex }, { conv, messages }) {
async *call({ fileMessageIndex, fileIndex }, { conv, messages, ip, username }) {
fileMessageIndex = Number(fileMessageIndex);
fileIndex = Number(fileIndex);

Expand All @@ -47,10 +47,13 @@ const documentParser: BackendTool = {
.then((file) => fetch(`data:${file.mime};base64,${file.value}`))
.then((res) => res.blob());

const ipToken = await getIpToken(ip, username);

const outputs = await callSpace<PdfParserInput, PdfParserOutput>(
"huggingchat/document-parser",
"predict",
[fileBlob, file.name]
[fileBlob, file.name],
ipToken
);

let documentMarkdown = outputs[0];
Expand Down
9 changes: 6 additions & 3 deletions src/lib/server/tools/images/editing.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import type { BackendTool } from "..";
import { uploadFile } from "../../files/uploadFile";
import { ToolResultStatus } from "$lib/types/Tool";
import { MessageUpdateType } from "$lib/types/MessageUpdate";
import { callSpace, type GradioImage } from "../utils";
import { callSpace, getIpToken, type GradioImage } from "../utils";
import { downloadFile } from "$lib/server/files/downloadFile";

type ImageEditingInput = [
Expand Down Expand Up @@ -37,7 +37,7 @@ const imageEditing: BackendTool = {
required: true,
},
},
async *call({ prompt, fileMessageIndex, fileIndex }, { conv, messages }) {
async *call({ prompt, fileMessageIndex, fileIndex }, { conv, messages, ip, username }) {
prompt = String(prompt);
fileMessageIndex = Number(fileMessageIndex);
fileIndex = Number(fileIndex);
Expand Down Expand Up @@ -68,6 +68,8 @@ const imageEditing: BackendTool = {
.then((file) => fetch(`data:${file.mime};base64,${file.value}`))
.then((res) => res.blob());

const ipToken = await getIpToken(ip, username);

const outputs = await callSpace<ImageEditingInput, ImageEditingOutput>(
"multimodalart/cosxl",
"run_edit",
Expand All @@ -77,7 +79,8 @@ const imageEditing: BackendTool = {
"", // negative prompt
7, // guidance scale
20, // steps
]
],
ipToken
);

const outputImage = await fetch(outputs[0].url)
Expand Down
9 changes: 6 additions & 3 deletions src/lib/server/tools/images/generation.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import type { BackendTool } from "..";
import { uploadFile } from "../../files/uploadFile";
import { ToolResultStatus } from "$lib/types/Tool";
import { MessageUpdateType } from "$lib/types/MessageUpdate";
import { callSpace, type GradioImage } from "../utils";
import { callSpace, getIpToken, type GradioImage } from "../utils";

type ImageGenerationInput = [
number /* number (numeric value between 1 and 8) in 'Number of Images' Slider component */,
Expand Down Expand Up @@ -44,7 +44,9 @@ const imageGeneration: BackendTool = {
default: 1024,
},
},
async *call({ prompt, numberOfImages }, { conv }) {
async *call({ prompt, numberOfImages }, { conv, ip, username }) {
const ipToken = await getIpToken(ip, username);

const outputs = await callSpace<ImageGenerationInput, ImageGenerationOutput>(
"ByteDance/Hyper-SDXL-1Step-T2I",
"/process_image",
Expand All @@ -54,7 +56,8 @@ const imageGeneration: BackendTool = {
512, // number in 'Image Width' Number component
String(prompt), // prompt
Math.floor(Math.random() * 1000), // seed random
]
],
ipToken
);
const imageBlobs = await Promise.all(
outputs[0].map((output) =>
Expand Down
14 changes: 5 additions & 9 deletions src/lib/server/tools/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,3 @@
import type { Assistant } from "$lib/types/Assistant";
import type { Conversation } from "$lib/types/Conversation";
import type { Message } from "$lib/types/Message";
import type { MessageUpdate } from "$lib/types/MessageUpdate";
import type { Tool, ToolResultError, ToolResultSuccess } from "$lib/types/Tool";

Expand All @@ -11,13 +8,12 @@ import imageGeneration from "./images/generation";
import documentParser from "./documentParser";
import fetchUrl from "./web/url";
import websearch from "./web/search";
import type { TextGenerationContext } from "../textGeneration/types";

export interface BackendToolContext {
conv: Conversation;
messages: Message[];
preprompt?: string;
assistant?: Pick<Assistant, "rag" | "dynamicPrompt" | "generateSettings">;
}
export type BackendToolContext = Pick<
TextGenerationContext,
"conv" | "messages" | "assistant" | "ip" | "username"
> & { preprompt?: string };

// typescript can't narrow a discriminated union after applying a generic like Omit to it
// so we have to define the omitted types and create a new union
Expand Down
29 changes: 27 additions & 2 deletions src/lib/server/tools/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import { env } from "$env/dynamic/private";
import { Client } from "@gradio/client";
import { SignJWT } from "jose";

export type GradioImage = {
path: string;
Expand All @@ -16,14 +17,38 @@ type GradioResponse = {
/**
 * Connects to the Gradio Space `name`, invokes `func` with `parameters`, and
 * resolves with the `data` field of the response cast to `TOutput`.
 *
 * @param name - Space identifier passed to `connect`.
 * @param func - Endpoint/function name passed to `predict`.
 * @param parameters - Positional inputs forwarded to `predict`.
 * @param ipToken - When truthy, sent as the `X-IP-Token` header on every
 *   request issued through the client's `fetch`; when falsy, no header is added.
 * @returns The `data` payload of the prediction response.
 */
export async function callSpace<TInput extends unknown[], TOutput extends unknown[]>(
name: string,
func: string,
parameters: TInput
parameters: TInput,
ipToken: string | undefined
): Promise<TOutput> {
const client = await Client.connect(name, {
// Subclass declared per call so the overridden `fetch` can close over `ipToken`.
class CustomClient extends Client {
fetch(input: RequestInfo | URL, init?: RequestInit): Promise<Response> {
init = init || {};
// NOTE(review): object spread only carries over headers given as a plain
// object; if the client ever passes a `Headers` instance or an array of
// pairs here, those entries would not be copied as header names — confirm.
init.headers = {
...(init.headers || {}),
...(ipToken ? { "X-IP-Token": ipToken } : {}),
};
return super.fetch(input, init);
}
}

const client = await CustomClient.connect(name, {
// Prefers HF_TOKEN and falls back to HF_ACCESS_TOKEN when it is null/undefined.
hf_token: (env.HF_TOKEN ?? env.HF_ACCESS_TOKEN) as unknown as `hf_${string}`,
});
return await client
.predict(func, parameters)
.then((res) => (res as unknown as GradioResponse).data as TOutput);
}

/**
 * Produces a short-lived signed JWT that carries the caller's IP and username.
 *
 * The token is signed with HS256 using `env.IP_TOKEN_SECRET`, stamped with an
 * issued-at time, and given an expiration of "1m". Its payload is
 * `{ ip, user: username }`.
 *
 * @param ip - Client IP address placed in the `ip` claim.
 * @param username - Optional username placed in the `user` claim.
 * @returns The compact JWT string, or `undefined` when `IP_TOKEN_SECRET` is
 *   not set (or is empty), in which case no token is produced.
 */
export async function getIpToken(ip: string, username?: string) {
	const secret = env.IP_TOKEN_SECRET;
	if (secret) {
		const signingKey = new TextEncoder().encode(secret);
		const unsignedToken = new SignJWT({ ip, user: username })
			.setProtectedHeader({ alg: "HS256" })
			.setIssuedAt()
			.setExpirationTime("1m");
		return await unsignedToken.sign(signingKey);
	}
	// No secret configured: callers receive undefined.
	return;
}

export { toolHasName } from "$lib/utils/tools";
2 changes: 2 additions & 0 deletions src/routes/conversation/[id]/+server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,8 @@ export async function POST({ request, locals, params, getClientAddress }) {
webSearch: webSearch ?? false,
toolsPreference: toolsPreferences ?? {},
promptedAt,
ip: getClientAddress(),
username: locals.user?.username,
};
// run the text generation and send updates to the client
for await (const event of textGeneration(ctx)) await update(event);
Expand Down

0 comments on commit 30b4335

Please sign in to comment.