From 26ecaf82f3c7ef4f1ef3f5e1473322c735c898c9 Mon Sep 17 00:00:00 2001 From: Ajeya Bhat Date: Fri, 29 Sep 2023 12:53:57 +0530 Subject: [PATCH 1/2] added inference endpoint --- src/comet.ts | 2 +- src/helpers.ts | 2 + src/inference.ts | 92 ++++++++++++++++++ src/types.ts | 1 - src/types/comet.ts | 176 ++++++++++++++++++++++++++++++++++ src/types/inference.ts | 17 ++++ src/utils/inference/io.ts | 71 ++++++++++++++ src/utils/inference/stream.ts | 170 ++++++++++++++++++++++++++++++++ 8 files changed, 529 insertions(+), 2 deletions(-) create mode 100644 src/inference.ts create mode 100644 src/types/comet.ts create mode 100644 src/types/inference.ts create mode 100644 src/utils/inference/io.ts create mode 100644 src/utils/inference/stream.ts diff --git a/src/comet.ts b/src/comet.ts index 9c9d393..52d4e51 100644 --- a/src/comet.ts +++ b/src/comet.ts @@ -10,7 +10,7 @@ import type { TCometPromptResponse, ICometSession, PromptOptions, -} from './types'; +} from './types/comet'; import { streamPromptWithAxios, streamPromptWithEventStreaming, diff --git a/src/helpers.ts b/src/helpers.ts index 0af22ad..7382879 100644 --- a/src/helpers.ts +++ b/src/helpers.ts @@ -156,6 +156,8 @@ export const streamPromptWithEventStreaming = async ( } catch (e) { throw new ClientError('Encountered error while parsing response into JSON.'); } + } else if (msg.event==='error') { + } }, // onclose() { diff --git a/src/inference.ts b/src/inference.ts new file mode 100644 index 0000000..b27ce01 --- /dev/null +++ b/src/inference.ts @@ -0,0 +1,92 @@ +import axios, { type AxiosInstance } from 'axios'; +import { API_V1_URL } from './constants'; +// import { +// streamPromptWithAxios, +// streamPromptWithEventStreaming, +// streamPromptWithNativeFetch, +// } from 'helpers'; +import { IInference, PromptPayload, PromptOptions } from 'types/inference'; + +export class Inference implements IInference { + readonly apiKey: string; + readonly inferenceId: string; + private readonly inferenceAPI: AxiosInstance; + + constructor(apiKey: string, inferenceId: string) { + this.apiKey = apiKey; + this.inferenceId = inferenceId; + + this.inferenceAPI = axios.create({ + baseURL: `${API_V1_URL}/inference/${inferenceId}`, + headers: { + Authorization: `Bearer ${apiKey}`, + 'Content-Type': 'application/json', + }, + }); + } + + async getInferenceInfo() { + const { data } = await this.inferenceAPI.get('/'); + return data; + } + + async prompt( + payload: PromptPayload, + handleNewText?: (data: string) => void | Promise, + options?: PromptOptions + ) { + return { tbd: true }; + //TODO: better error handling + // if (payload.stream) { + // if (typeof window !== 'undefined') { + // if (options?.useNativeFetch) { + // const resp = await streamPromptWithNativeFetch( + // this.cometId, + // this.apiKey, + // payload, + // handleNewText + // ); + // // @ts-ignore + // if (resp.error) { + // // @ts-ignore + // throw new Error(resp.error); + // } else { + // return resp as TCometPromptResponse; + // } + // } else { + // const resp = await streamPromptWithEventStreaming( + // this.inferenceAPI, + // this.apiKey, + // payload, + // handleNewText + // ); + // // @ts-ignore + // if (resp.error) { + // // @ts-ignore + // throw new Error(resp.error); + // } else { + // return resp as TCometPromptResponse; + // } + // } + // } else { + // const resp = await streamPromptWithAxios(this.inferenceAPI, payload, handleNewText); + // // @ts-ignore + // if (resp.error) { + // // @ts-ignore + // throw new Error(resp.error); + // } else { + // return resp as TCometPromptResponse; + // } + // } + // } else { + // const { data } = await this.cometAPI.post(`/prompt`, payload); + // return data as TCometPromptResponse; + // } + } + + async delete(): Promise { + await this.inferenceAPI.delete('/'); + } +} + +export default Inference; diff --git a/src/types.ts b/src/types.ts index 719f39e..1074ead 100644 --- a/src/types.ts +++ b/src/types.ts @@ -50,7 +50,6 @@ export interface IComet { ICometConversation & { messages: ICometMessage[]; stats: ICometConversationStats | null } >; } - export interface IndexInput { indexId?: string; id?: string; diff --git a/src/types/comet.ts b/src/types/comet.ts new file mode 100644 index 0000000..d09c66b --- /dev/null +++ b/src/types/comet.ts @@ -0,0 +1,176 @@ +// Comet Client +export interface IComet { + prompt: ( + payload: PromptPayload, + handleNewText?: (data: string) => void | Promise, + options?: PromptOptions + ) => Promise; + update: (payload: UpdateCometPayload) => Promise; + updateModel: (payload: UpdateModelPayload) => Promise; + getMessage: (payload: GetMessagePayload) => Promise; + takeConversationFeedback: (payload: ProvideMessageFeedbackPayload) => Promise; + deleteComet: () => Promise; + getCometInfo: () => Promise; + listSessions: (payload: ListSessionsPayload) => Promise; + getSession: (payload: GetSessionPayload) => Promise; + listConversations: ( + payload: ListConversationsPayload + ) => Promise< + Array< + ICometConversation & { + messages: M extends true ? ICometMessage[] : never; + stats: S extends true ? ICometConversationStats | null : never; + } + > + >; + getConversation: ( + payload: GetConversationPayload + ) => Promise< + ICometConversation & { messages: ICometMessage[]; stats: ICometConversationStats | null } + >; +} + +export interface PromptOptions { + useNativeFetch?: boolean; +} + +export interface PromptPayload { + input: string; + prompt_variables?: Record; + channel?: string; + visitorId?: string; + sessionId?: string; + stream: boolean; + configs?: { + max_tokens?: number; + temperature?: number; + top_p?: number; + presence_penalty?: number; + frequency_penalty?: number; + }; +} + +export interface ListConversationsPayload { + sessionId: string; + messages?: boolean; + stats?: boolean; +} + +export interface ListSessionsPayload { + userId?: string; + channel?: string; + visitorId?: string; +} +export interface GetSessionPayload { + sessionId: string; +} +export interface GetConversationPayload { + conversationId: string; +} +export interface GetMessagePayload { + messageId: string; +} + +export interface ProvideMessageFeedbackPayload { + conversationId: string; + vote?: boolean; + feedback?: string; + meta?: object; +} + +export interface UpdateCometPayload { + sectionsMatchThreshold?: number; + sectionMatchCount?: number; + name?: string; +} +export type TCometModelType = 'selfhosted' | 'thirdparty'; + +export interface UpdateModelPayload { + type?: TCometModelType; + details?: { name: string; vendor: 'openai' }; + configs?: object; +} + +export type CometAIModelType = 'text' | 'chat'; + +export interface ICometInfo { + projectId: string; + id: string; + createdAt: string; + updatedAt: string; + creatorId: string; + thirdPartyKeyId: string | null; + confluxId: string | null; + name: string; + configs: Record | null; + whitelistedDomains: string[]; + promptVariables: Record; + promptTemplate: string | null; + promptTokenLimit: number | null; + sectionsMatchThreshold: number | null; + sectionMatchCount: number | null; + contextTokenCutoff: number | null; + model: string; + modelVendor: string; + modelType: CometAIModelType; + conversationHistoryCutoff?: string; +} + +export interface ICometSession { + id: string; + channel: string; + metadata: Record; + userId: string; + visitorId: string | null; + createdAt: string; + updatedAt: string; +} + +export interface ICometConversationStats { + id: string; + feedback: string | null; + noResponse: boolean; + upvoted: boolean; + updatedAt: string; + downvoted: boolean; + processed: boolean; +} + +export interface ICometConversation { + id: string; + metadata: Record | null; + createdAt: string; + updatedAt: string; +} + +export type CometMessageAuthor = 'agent' | 'human' | 'system' | 'function'; + +export interface ICometMessage { + id: string; + text: string; + from: CometMessageAuthor; + meta: Record | null; + conversationId: string; + createdAt: string; +} + +export type TCometPromptResponse = { + generations: string[]; + meta: { + referencePaths?: string[]; + referencesWithSources?: { + path: string; + source_id: string; + }[]; + }; + usage: { + prompt_tokens: number; + completion_tokens: number; + }; + conversationId?: string; + sessionId?: string; +}; + +export type TCometPromptStreamResponseError = { + error: string; +}; diff --git a/src/types/inference.ts b/src/types/inference.ts new file mode 100644 index 0000000..6b5abe9 --- /dev/null +++ b/src/types/inference.ts @@ -0,0 +1,17 @@ +export interface IInference { + prompt: ( + payload: PromptPayload & Record, + handleNewText?: (data: string) => void | Promise, + options?: PromptOptions + ) => Promise; + getInferenceInfo: () => Promise; +} + +export interface PromptOptions { + useNativeFetch?: boolean; +} + +export interface PromptPayload { + prompt: string; + stream: boolean; +} diff --git a/src/utils/inference/io.ts b/src/utils/inference/io.ts new file mode 100644 index 0000000..c583153 --- /dev/null +++ b/src/utils/inference/io.ts @@ -0,0 +1,71 @@ +export const VLLMPromptConfigs = { + type: 'object', + properties: { + n: { + type: 'integer', + minimum: 1, + }, + best_of: { + type: ['integer', 'null'], + }, + presence_penalty: { + type: 'number', + minimum: -2.0, + maximum: 2.0, + }, + frequency_penalty: { + type: 'number', + minimum: -2.0, + maximum: 2.0, + }, + temperature: { + type: 'number', + minimum: 0.0, + }, + top_p: { + type: 'number', + minimum: 0.0, + maximum: 1.0, + }, + top_k: { + type: 'integer', + minimum: -1, + }, + use_beam_search: { + type: 'boolean', + }, + length_penalty: { + type: 'number', + minimum: 0.0, + }, + early_stopping: { + type: ['boolean', 'string'], + }, + stop: { + type: ['array', 'string', 'null'], + items: { + type: 'string', + }, + }, + stop_token_ids: { + type: 'array', + items: { + type: 'integer', + }, + }, + ignore_eos: { + type: 'boolean', + }, + max_tokens: { + type: 'integer', + minimum: 1, + }, + logprobs: { + type: ['integer', 'null'], + minimum: 0, + }, + skip_special_tokens: { + type: 'boolean', + }, + }, +}; diff --git a/src/utils/inference/stream.ts b/src/utils/inference/stream.ts new file mode 100644 index 0000000..c172874 --- /dev/null +++ b/src/utils/inference/stream.ts @@ -0,0 +1,170 @@ +// import { AxiosInstance } from 'axios'; +// import { API_V1_URL } from '../../constants'; +// import { PromptPayload, TCometPromptResponse } from 'types'; +// import { EventStreamContentType, fetchEventSource } from '@microsoft/fetch-event-source'; +// import { APIError } from 'error'; + +// export const streamPromptWithNativeFetch = ( +// cometId: string, +// apiKey: string, +// payload: PromptPayload, +// handleNewText?: (token: string) => void | Promise +// ) => { +// return new Promise((resolve, reject) => { +// (async () => { +// const response = await fetch(`${API_V1_URL}/comets/${cometId}/prompt`, { +// method: 'POST', +// body: JSON.stringify(payload), +// headers: { +// 'Content-Type': 'application/json', +// Authorization: `Bearer ${apiKey}`, +// Accept: 'text/plain, application/json', +// }, +// }); +// if (response.ok) { +// if (!response.body) return reject('No response body found.'); +// const reader = response.body.getReader(); +// let responsePrefixReceived = false; +// let responseText: string = ''; +// const textDecoder = new TextDecoder(); +// let done: boolean = false, +// value: Uint8Array | undefined; +// while (!done) { +// ({ done, value } = await reader.read()); +// if (value) { +// let chunk = textDecoder.decode(value); +// if (!responsePrefixReceived && chunk.includes(PROMPT_STREAM_RESPONSE_PREFIX)) { +// const splitChunks = chunk.split(PROMPT_STREAM_RESPONSE_PREFIX); +// if (splitChunks.length === 2) { +// handleNewText?.(splitChunks[0]); +// responseText = responseText.concat(splitChunks[1] ?? ''); +// } else return reject('Could not parse the response'); +// responsePrefixReceived = true; +// } else if (responsePrefixReceived) { +// responseText = responseText.concat(chunk); +// } else { +// handleNewText?.(chunk); +// } +// } +// if (done) { +// try { +// return resolve(JSON.parse(responseText)); +// } catch (e) { +// return reject('Could not parse the response'); +// } +// } +// } +// } else { +// if (response.headers.get('Content-Type') === 'application/json') +// return reject({ body: await response.json(), status: response.status }); +// else return reject(`Request failed.`); +// } +// })(); +// }); +// }; + +// export const streamPromptWithAxios = ( +// cometAPI: AxiosInstance, +// payload: PromptPayload, +// handleNewText?: (token: string) => void | Promise +// ) => { +// return new Promise((resolve, reject) => { +// (async () => { +// const { data: stream } = await cometAPI.post(`/prompt`, payload, { +// responseType: 'stream', +// }); +// let responsePrefixReceived = false; +// let responseText: string = ''; +// stream.on('data', (data: BinaryData) => { +// let chunk = data.toString(); + +// if (!responsePrefixReceived && chunk.includes(PROMPT_STREAM_RESPONSE_PREFIX)) { +// const splitChunks = chunk.split(PROMPT_STREAM_RESPONSE_PREFIX); +// if (splitChunks.length === 2) { +// handleNewText?.(splitChunks[0]); +// responseText = responseText.concat(splitChunks[1] ?? ''); +// } else return reject('Could not parse the response'); +// responsePrefixReceived = true; +// } else if (responsePrefixReceived) { +// responseText = responseText.concat(chunk); +// } else { +// handleNewText?.(chunk); +// } +// }); + +// stream.on('end', () => { +// return resolve(JSON.parse(responseText)); +// }); +// })(); +// }); +// }; + +// class ClientError extends Error {} +// // class FatalError extends Error {} + +// export const streamPromptWithEventStreaming = async ( +// cometId: string, +// apiKey: string, +// payload: PromptPayload, +// handleNewText?: (token: string) => void | Promise +// ): Promise => { +// try { +// let finalResponse: TCometPromptResponse; +// await fetchEventSource(`${API_V1_URL}/comets/${cometId}/prompt`, { +// method: 'POST', +// headers: { +// 'Content-Type': 'application/json', +// Accept: 'text/event-stream', +// Authorization: `Bearer ${apiKey}`, +// }, +// body: JSON.stringify(payload), +// async onopen(response) { +// const contentType = response.headers.get('content-type'); +// if (response.ok && contentType.includes(EventStreamContentType)) { +// return; // everything's good +// } else { +// if (contentType === 'application/json') { +// const body = await response.json(); +// throw new APIError({ status: response.status, message: body?.message }); +// } else { +// throw new APIError({ +// status: response.status, +// message: `Request failed with status code: ${response.status}`, +// }); +// } +// } +// }, +// onmessage(msg) { +// // if the server emits an error message, throw an exception +// // so it gets handled by the onerror callback below: +// if (msg.event === 'token') { +// handleNewText(msg.data); +// } else if (msg.event === 'response') { +// try { +// finalResponse = JSON.parse(msg.data); +// } catch (e) { +// throw new ClientError('Encountered error while parsing response into JSON.'); +// } +// } else if (msg.event === 'error') { +// } +// }, +// // onclose() { +// // // if the server closes the connection unexpectedly, retry: +// // throw new ClientError('Server closed connection.'); +// // }, +// onerror(err) { +// throw err; // rethrow to stop the operation +// }, + +// // signal: ctrl.signal, +// }); +// return finalResponse; +// } catch (e) { +// if (e instanceof APIError) { +// return { error: e.message }; +// } else { +// console.error('Unknown Error while prompting:', e); +// return { error: e?.message || 'Unknown Error' }; +// } +// } +// }; From 4852207ceed129aa3e893f560a48273841bec37a Mon Sep 17 00:00:00 2001 From: Ajeya Bhat Date: Fri, 29 Sep 2023 13:11:49 +0530 Subject: [PATCH 2/2] changeset --- .changeset/hot-vans-compete.md | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 .changeset/hot-vans-compete.md diff --git a/.changeset/hot-vans-compete.md b/.changeset/hot-vans-compete.md new file mode 100644 index 0000000..8ec42e6 --- /dev/null +++ b/.changeset/hot-vans-compete.md @@ -0,0 +1,5 @@ +--- +'outpostkit': patch +--- + +Inference Endpoints