Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Stop generation button (closes #86) #88

Merged
merged 11 commits into from
Apr 26, 2023
17 changes: 17 additions & 0 deletions src/lib/components/StopGeneratingBtn.svelte
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
<script lang="ts">
import CarbonPause from "~icons/carbon/pause-filled";

export let visible: boolean = false;
export let className = "";
</script>

<button
type="button"
on:click
class="absolute btn flex rounded-lg border py-1 px-3 shadow-sm bg-white dark:bg-gray-700 hover:bg-gray-100 dark:hover:bg-gray-600 dark:border-gray-600 transition-all
{className}
{visible ? 'opacity-100 visible' : 'opacity-0 invisible'}
"
>
<CarbonPause class="mr-1 -ml-1 w-[1.1875rem] h-[1.25rem] text-gray-400" /> Stop generating
</button>
12 changes: 9 additions & 3 deletions src/lib/components/chat/ChatWindow.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import { createEventDispatcher } from "svelte";

import CarbonSendAltFilled from "~icons/carbon/send-alt-filled";
import CarbonExport from "~icons/carbon/export";

import ChatMessages from "./ChatMessages.svelte";
import ChatInput from "./ChatInput.svelte";
import CarbonExport from "~icons/carbon/export";
import StopGeneratingBtn from "../StopGeneratingBtn.svelte";
import { PUBLIC_MODEL_ID, PUBLIC_MODEL_NAME } from "$env/static/public";

export let messages: Message[] = [];
Expand All @@ -16,7 +17,7 @@

let message: string;

const dispatch = createEventDispatcher<{ message: string; share: void }>();
const dispatch = createEventDispatcher<{ message: string; share: void; stop: void }>();

const handleSubmit = () => {
if (loading) return;
Expand All @@ -28,8 +29,13 @@
<div class="relative min-h-0 min-w-0">
<ChatMessages {loading} {pending} {messages} on:message />
<div
class="flex flex-col pointer-events-none [&>*]:pointer-events-auto max-md:border-t dark:border-gray-800 items-center max-md:dark:bg-gray-900 max-md:bg-white bg-gradient-to-t from-white via-white/80 to-white/0 dark:from-gray-900 dark:via-gray-80 dark:to-gray-900/0 justify-center absolute inset-x-0 max-w-3xl xl:max-w-4xl mx-auto px-3.5 sm:px-5 bottom-0 py-4 md:py-8 w-full"
class="flex flex-col pointer-events-none [&>*]:pointer-events-auto max-md:border-t dark:border-gray-800 items-center max-md:dark:bg-gray-900 max-md:bg-white bg-gradient-to-t from-white via-white/80 to-white/0 dark:from-gray-900 dark:via-gray-80 dark:to-gray-900/0 justify-center absolute inset-x-0 max-w-3xl xl:max-w-4xl mx-auto px-3.5 sm:px-5 bottom-0 py-4 md:py-8 w-full z-0"
>
<StopGeneratingBtn
visible={loading}
className="right-5 mr-[1px] md:mr-0 md:right-7 top-6 md:top-10 z-10"
on:click={() => dispatch("stop")}
/>
<form
on:submit|preventDefault={handleSubmit}
class="w-full relative flex items-center rounded-xl flex-1 max-w-4xl border bg-gray-100 focus-within:border-gray-300 dark:bg-gray-700 dark:border-gray-600 dark:focus-within:border-gray-500 "
Expand Down
29 changes: 29 additions & 0 deletions src/lib/server/abortedGenerations.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Shouldn't be needed if we dove into sveltekit internals, see https://github.com/huggingface/chat-ui/pull/88#issuecomment-1523173850

import { setTimeout } from "node:timers/promises";
import { collections } from "./database";

let closed = false;
process.on("SIGINT", () => {
closed = true;
});

export let abortedGenerations: Map<string, Date> = new Map();

async function maintainAbortedGenerations() {
while (!closed) {
await setTimeout(1000);

try {
const aborts = await collections.abortedGenerations.find({}).sort({ createdAt: 1 }).toArray();

abortedGenerations = new Map(
aborts.map(({ conversationId, createdAt }) => [conversationId.toString(), createdAt])
);
} catch (err) {
console.error(err);
}
}
}

maintainAbortedGenerations();
6 changes: 5 additions & 1 deletion src/lib/server/database.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { MONGODB_URL, MONGODB_DB_NAME } from "$env/static/private";
import { MongoClient } from "mongodb";
import type { Conversation } from "$lib/types/Conversation";
import type { SharedConversation } from "$lib/types/SharedConversation";
import type { AbortedGeneration } from "$lib/types/AbortedGeneration";

const client = new MongoClient(MONGODB_URL, {
// directConnection: true
Expand All @@ -13,11 +14,14 @@ const db = client.db(MONGODB_DB_NAME);

const conversations = db.collection<Conversation>("conversations");
const sharedConversations = db.collection<SharedConversation>("sharedConversations");
const abortedGenerations = db.collection<AbortedGeneration>("abortedGenerations");

export { client, db };
export const collections = { conversations, sharedConversations };
export const collections = { conversations, sharedConversations, abortedGenerations };

client.on("open", () => {
conversations.createIndex({ sessionId: 1, updatedAt: -1 });
abortedGenerations.createIndex({ updatedAt: 1 }, { expireAfterSeconds: 30 });
abortedGenerations.createIndex({ conversationId: 1 }, { unique: true });
sharedConversations.createIndex({ hash: 1 }, { unique: true });
});
9 changes: 9 additions & 0 deletions src/lib/types/AbortedGeneration.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// Ideally shouldn't be needed, see https://github.com/huggingface/chat-ui/pull/88#issuecomment-1523173850

import type { Conversation } from "./Conversation";

export interface AbortedGeneration {
createdAt: Date;
updatedAt: Date;
conversationId: Conversation["_id"];
}
12 changes: 12 additions & 0 deletions src/lib/utils/concatUint8Arrays.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
import { sum } from "./sum";

export function concatUint8Arrays(arrays: Uint8Array[]): Uint8Array {
const totalLength = sum(arrays.map((a) => a.length));
const result = new Uint8Array(totalLength);
let offset = 0;
for (const array of arrays) {
result.set(array, offset);
offset += array.length;
}
return result;
}
22 changes: 21 additions & 1 deletion src/routes/conversation/[id]/+page.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

let messages = data.messages;
let lastLoadedMessages = data.messages;
let isAborted = false;

// Since we modify the messages array locally, we don't want to reset it if an old version is passed
$: if (data.messages !== lastLoadedMessages) {
Expand Down Expand Up @@ -55,7 +56,24 @@
for await (const data of response) {
pending = false;

if (!data || conversationId !== $page.params.id) break;
if (!data) {
break;
}

if (conversationId !== $page.params.id) {
fetch(`${base}/conversation/${conversationId}/stop-generating`, {
method: "POST",
}).catch(console.error);
break;
}

if (isAborted) {
isAborted = false;
fetch(`${base}/conversation/${conversationId}/stop-generating`, {
method: "POST",
}).catch(console.error);
break;
}

// final message
if (data.generated_text) {
Expand Down Expand Up @@ -91,6 +109,7 @@
if (!message.trim()) return;

try {
isAborted = false;
loading = true;
pending = true;

Expand Down Expand Up @@ -130,4 +149,5 @@
{messages}
on:message={(message) => writeMessage(message.detail)}
on:share={() => shareConversation($page.params.id, data.title)}
on:stop={() => (isAborted = true)}
/>
46 changes: 37 additions & 9 deletions src/routes/conversation/[id]/+server.ts
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
import { PUBLIC_SEP_TOKEN } from "$env/static/public";
import { buildPrompt } from "$lib/buildPrompt.js";
import { abortedGenerations } from "$lib/server/abortedGenerations.js";
import { collections } from "$lib/server/database.js";
import { modelEndpoint } from "$lib/server/modelEndpoint.js";
import type { Message } from "$lib/types/Message.js";
import { concatUint8Arrays } from "$lib/utils/concatUint8Arrays.js";
import { streamToAsyncIterable } from "$lib/utils/streamToAsyncIterable";
import { sum } from "$lib/utils/sum";
import { trimPrefix } from "$lib/utils/trimPrefix.js";
import { trimSuffix } from "$lib/utils/trimSuffix.js";
import type { TextGenerationStreamOutput } from "@huggingface/inference";
import { error } from "@sveltejs/kit";
import { ObjectId } from "mongodb";

export async function POST({ request, fetch, locals, params }) {
// todo: add validation on params.id
const convId = new ObjectId(params.id);
const date = new Date();

const conv = await collections.conversations.findOne({
_id: convId,
Expand All @@ -31,6 +34,8 @@ export async function POST({ request, fetch, locals, params }) {

const randomEndpoint = modelEndpoint();

const abortController = new AbortController();

const resp = await fetch(randomEndpoint.endpoint, {
headers: {
"Content-Type": request.headers.get("Content-Type") ?? "application/json",
Expand All @@ -41,12 +46,13 @@ export async function POST({ request, fetch, locals, params }) {
...json,
inputs: prompt,
}),
signal: abortController.signal,
});

const [stream1, stream2] = resp.body!.tee();

async function saveMessage() {
let generated_text = await parseGeneratedText(stream2);
let generated_text = await parseGeneratedText(stream2, convId, date, abortController);

// We could also check if PUBLIC_ASSISTANT_MESSAGE_TOKEN is present and use it to slice the text
if (generated_text.startsWith(prompt)) {
Expand Down Expand Up @@ -97,19 +103,41 @@ export async function DELETE({ locals, params }) {
return new Response();
}

async function parseGeneratedText(stream: ReadableStream): Promise<string> {
async function parseGeneratedText(
stream: ReadableStream,
conversationId: ObjectId,
promptedAt: Date,
abortController: AbortController
): Promise<string> {
const inputs: Uint8Array[] = [];
for await (const input of streamToAsyncIterable(stream)) {
inputs.push(input);

const date = abortedGenerations.get(conversationId.toString());

if (date && date > promptedAt) {
abortController.abort("Cancelled by user");
const completeInput = concatUint8Arrays(inputs);

const lines = new TextDecoder()
.decode(completeInput)
.split("\n")
.filter((line) => line.startsWith("data:"));

const tokens = lines.map((line) => {
try {
const json: TextGenerationStreamOutput = JSON.parse(line.slice("data:".length));
return json.token.text;
} catch {
return "";
}
});
return tokens.join("");
}
}

// Merge inputs into a single Uint8Array
const completeInput = new Uint8Array(sum(inputs.map((input) => input.length)));
let offset = 0;
for (const input of inputs) {
completeInput.set(input, offset);
offset += input.length;
}
const completeInput = concatUint8Arrays(inputs);

// Get last line starting with "data:" and parse it as JSON to get the generated text
const message = new TextDecoder().decode(completeInput);
Expand Down
27 changes: 27 additions & 0 deletions src/routes/conversation/[id]/stop-generating/+server.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import { collections } from "$lib/server/database";
import { error } from "@sveltejs/kit";
import { ObjectId } from "mongodb";

/**
* Ideally, we'd be able to detect the client-side abort, see https://github.com/huggingface/chat-ui/pull/88#issuecomment-1523173850
*/
export async function POST({ params, locals }) {
const conversationId = new ObjectId(params.id);

const conversation = await collections.conversations.findOne({
_id: conversationId,
sessionId: locals.sessionId,
});

if (!conversation) {
throw error(404, "Conversation not found");
}

await collections.abortedGenerations.updateOne(
{ conversationId },
{ $set: { updatedAt: new Date() }, $setOnInsert: { createdAt: new Date() } },
{ upsert: true }
);

return new Response();
}