From 1854c7af15c21bfea14ea9431a86e04c094d10fa Mon Sep 17 00:00:00 2001 From: Ajeya Bhat Date: Mon, 9 Oct 2023 19:07:10 +0530 Subject: [PATCH 1/2] added support for abort and chat completion types --- src/comet.ts | 15 +++++++++++---- src/helpers.ts | 15 ++++++++++----- src/types/comet.ts | 1 + src/utils/inference/vllm-stream.ts | 22 +++++++++++++++------- 4 files changed, 37 insertions(+), 16 deletions(-) diff --git a/src/comet.ts b/src/comet.ts index 52d4e51..8f830f5 100644 --- a/src/comet.ts +++ b/src/comet.ts @@ -53,7 +53,8 @@ export class Comet implements IComet { this.cometId, this.apiKey, payload, - handleNewText + handleNewText, + options?.signal ); // @ts-ignore @@ -68,7 +69,8 @@ export class Comet implements IComet { this.cometId, this.apiKey, payload, - handleNewText + handleNewText, + options?.signal ); // @ts-ignore if (resp.error) { @@ -79,7 +81,12 @@ export class Comet implements IComet { } } } else { - const resp = await streamPromptWithAxios(this.cometAPI, payload, handleNewText); + const resp = await streamPromptWithAxios( + this.cometAPI, + payload, + handleNewText, + options?.signal + ); // @ts-ignore if (resp.error) { // @ts-ignore @@ -89,7 +96,7 @@ export class Comet implements IComet { } } } else { - const { data } = await this.cometAPI.post(`/prompt`, payload); + const { data } = await this.cometAPI.post(`/prompt`, payload, { signal: options?.signal }); return data as TCometPromptResponse; } } diff --git a/src/helpers.ts b/src/helpers.ts index 7382879..b3633d8 100644 --- a/src/helpers.ts +++ b/src/helpers.ts @@ -19,7 +19,8 @@ export const streamPromptWithNativeFetch = ( cometId: string, apiKey: string, payload: PromptPayload, - handleNewText?: (token: string) => void | Promise + handleNewText?: (token: string) => void | Promise, + signal?: AbortSignal ) => { return new Promise((resolve, reject) => { (async () => { @@ -31,6 +32,7 @@ export const streamPromptWithNativeFetch = ( Authorization: `Bearer ${apiKey}`, Accept: 'text/plain, application/json', }, + signal, }); if (response.ok) { if (!response.body) return reject('No response body found.'); @@ -77,12 +79,14 @@ export const streamPromptWithNativeFetch = ( export const streamPromptWithAxios = ( cometAPI: AxiosInstance, payload: PromptPayload, - handleNewText?: (token: string) => void | Promise + handleNewText?: (token: string) => void | Promise, + signal?: AbortSignal ) => { return new Promise((resolve, reject) => { (async () => { const { data: stream } = await cometAPI.post(`/prompt`, payload, { responseType: 'stream', + signal, }); let responsePrefixReceived = false; let responseText: string = ''; @@ -117,7 +121,8 @@ export const streamPromptWithEventStreaming = async ( cometId: string, apiKey: string, payload: PromptPayload, - handleNewText?: (token: string) => void | Promise + handleNewText?: (token: string) => void | Promise, + signal?: AbortSignal ): Promise => { try { let finalResponse: TCometPromptResponse | TCometPromptStreamResponseError; @@ -128,6 +133,7 @@ export const streamPromptWithEventStreaming = async ( Accept: 'text/event-stream', Authorization: `Bearer ${apiKey}`, }, + signal, body: JSON.stringify(payload), async onopen(response) { const contentType = response.headers.get('content-type'); @@ -156,8 +162,7 @@ export const streamPromptWithEventStreaming = async ( } catch (e) { throw new ClientError('Encountered error while parsing response into JSON.'); } - } else if (msg.event==='error') { - + } else if (msg.event === 'error') { } }, // onclose() { diff --git a/src/types/comet.ts b/src/types/comet.ts index d09c66b..36da0d2 100644 --- a/src/types/comet.ts +++ b/src/types/comet.ts @@ -32,6 +32,7 @@ export interface IComet { export interface PromptOptions { useNativeFetch?: boolean; + signal?: AbortSignal; } export interface PromptPayload { diff --git a/src/utils/inference/vllm-stream.ts b/src/utils/inference/vllm-stream.ts index e091d7d..240f85a 100644 --- a/src/utils/inference/vllm-stream.ts +++ b/src/utils/inference/vllm-stream.ts @@ -1,11 +1,12 @@ import { EventStreamContentType, fetchEventSource } from '@microsoft/fetch-event-source'; import { APIError } from 'error'; -import { VLLMOpenAICompletionsOutputType, VLLMPromptParameters } from 'types/inference'; +import { VLLMPromptParameters } from 'types/inference'; export const streamGenericInferenceServer = ( domain: string, payload: { prompt: string } & VLLMPromptParameters, - handleNewChunk?: (chunk: string) => void | Promise + handleNewChunk?: (chunk: string) => void | Promise, + signal?: AbortSignal ) => { return new Promise((resolve, reject) => { (async () => { @@ -16,6 +17,7 @@ export const streamGenericInferenceServer = ( 'Content-Type': 'application/json', Accept: 'text/plain, application/json', }, + signal, }); if (response.ok) { if (!response.body) return reject('No response body found.'); @@ -88,12 +90,17 @@ export const streamGenericInferenceServer = ( class ClientError extends Error {} // class FatalError extends Error {} -export const streamOpenAIInferenceServer = async ( - payload: { prompt: string; model: string } & VLLMPromptParameters, +export type ChatMessage = { role: 'system' | 'user' | 'assistant'; content: string }; +export async function streamOpenAIInferenceServer( + payload: { + prompt: T extends 'chat' ? ChatMessage[] : T extends 'text' ? string : never; + model: string; + } & VLLMPromptParameters, domain: string, - type: 'chat' | 'text', - handleNewChunk?: (chunk: string) => void | Promise -): Promise => { + type: T, + handleNewChunk?: (chunk: string) => void | Promise, + signal?: AbortSignal +): Promise { try { let finalResponse: string; await fetchEventSource(`${domain}/v1/${type === 'chat' ? 'chat/completions' : 'completions'}`, { @@ -102,6 +109,7 @@ export const streamOpenAIInferenceServer = async ( 'Content-Type': 'application/json', Accept: 'text/event-stream', }, + signal, body: JSON.stringify({ ...payload, stream: true }), async onopen(response) { const contentType = response.headers.get('content-type'); From 8683f0ceec664768a4665fd28c85a2485885f08f Mon Sep 17 00:00:00 2001 From: Ajeya Bhat Date: Tue, 10 Oct 2023 14:46:57 +0530 Subject: [PATCH 2/2] changeset --- .changeset/thin-waves-breathe.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/thin-waves-breathe.md diff --git a/.changeset/thin-waves-breathe.md b/.changeset/thin-waves-breathe.md new file mode 100644 index 0000000..6ecde73 --- /dev/null +++ b/.changeset/thin-waves-breathe.md @@ -0,0 +1,5 @@ +--- +'outpostkit': patch +--- + +added abort feature