From f95d1858f0195f8b780221f240d2ef93d089aa25 Mon Sep 17 00:00:00 2001
From: Lars Grammel
Date: Sat, 30 Dec 2023 12:37:59 +0100
Subject: [PATCH] Extract AbstractOpenAICompletionModel.

---
 .../openai/AbstractOpenAICompletionModel.ts   | 237 ++++++++++++++++++
 .../openai/OpenAICompletionModel.ts           | 235 +----------------
 .../src/model-provider/openai/index.ts        |   1 +
 3 files changed, 247 insertions(+), 226 deletions(-)
 create mode 100644 packages/modelfusion/src/model-provider/openai/AbstractOpenAICompletionModel.ts

diff --git a/packages/modelfusion/src/model-provider/openai/AbstractOpenAICompletionModel.ts b/packages/modelfusion/src/model-provider/openai/AbstractOpenAICompletionModel.ts
new file mode 100644
index 000000000..eb4d109f7
--- /dev/null
+++ b/packages/modelfusion/src/model-provider/openai/AbstractOpenAICompletionModel.ts
@@ -0,0 +1,237 @@
+import { z } from "zod";
+import { FunctionOptions } from "../../core/FunctionOptions.js";
+import { ApiConfiguration } from "../../core/api/ApiConfiguration.js";
+import { callWithRetryAndThrottle } from "../../core/api/callWithRetryAndThrottle.js";
+import {
+  ResponseHandler,
+  createJsonResponseHandler,
+  postJsonToApi,
+} from "../../core/api/postToApi.js";
+import { zodSchema } from "../../core/schema/ZodSchema.js";
+import { AbstractModel } from "../../model-function/AbstractModel.js";
+import { TextGenerationModelSettings } from "../../model-function/generate-text/TextGenerationModel.js";
+import { TextGenerationFinishReason } from "../../model-function/generate-text/TextGenerationResult.js";
+import { createEventSourceResponseHandler } from "../../util/streaming/createEventSourceResponseHandler.js";
+import { OpenAIApiConfiguration } from "./OpenAIApiConfiguration.js";
+import { failedOpenAICallResponseHandler } from "./OpenAIError.js";
+
+export interface AbstractOpenAICompletionModelSettings
+  extends TextGenerationModelSettings {
+  api?: ApiConfiguration;
+
+  model: string;
+
+  suffix?: string;
+  temperature?: number;
+  topP?: number;
+  logprobs?: number;
+  echo?: boolean;
+  presencePenalty?: number;
+  frequencyPenalty?: number;
+  bestOf?: number;
+  logitBias?: Record<number, number>;
+  seed?: number | null;
+
+  isUserIdForwardingEnabled?: boolean;
+}
+
+/**
+ * Abstract completion model that calls an API that is compatible with the OpenAI completions API.
+ *
+ * @see https://platform.openai.com/docs/api-reference/completions/create
+ */
+export abstract class AbstractOpenAICompletionModel<
+  SETTINGS extends AbstractOpenAICompletionModelSettings,
+> extends AbstractModel<SETTINGS> {
+  constructor(settings: SETTINGS) {
+    super({ settings });
+  }
+
+  async callAPI<RESULT>(
+    prompt: string,
+    options: {
+      responseFormat: OpenAITextResponseFormatType<RESULT>;
+    } & FunctionOptions
+  ): Promise<RESULT> {
+    const api = this.settings.api ?? new OpenAIApiConfiguration();
+    const user = this.settings.isUserIdForwardingEnabled
+      ? options.run?.userId
+      : undefined;
+    const abortSignal = options.run?.abortSignal;
+    const openaiResponseFormat = options.responseFormat;
+
+    // empty arrays are not allowed for stop:
+    const stopSequences =
+      this.settings.stopSequences != null &&
+      Array.isArray(this.settings.stopSequences) &&
+      this.settings.stopSequences.length === 0
+        ? undefined
+        : this.settings.stopSequences;
+
+    return callWithRetryAndThrottle({
+      retry: api.retry,
+      throttle: api.throttle,
+      call: async () => {
+        return postJsonToApi({
+          url: api.assembleUrl("/completions"),
+          headers: api.headers,
+          body: {
+            stream: openaiResponseFormat.stream,
+            model: this.settings.model,
+            prompt,
+            suffix: this.settings.suffix,
+            max_tokens: this.settings.maxGenerationTokens,
+            temperature: this.settings.temperature,
+            top_p: this.settings.topP,
+            n: this.settings.numberOfGenerations,
+            logprobs: this.settings.logprobs,
+            echo: this.settings.echo,
+            stop: stopSequences,
+            seed: this.settings.seed,
+            presence_penalty: this.settings.presencePenalty,
+            frequency_penalty: this.settings.frequencyPenalty,
+            best_of: this.settings.bestOf,
+            logit_bias: this.settings.logitBias,
+            user,
+          },
+          failedResponseHandler: failedOpenAICallResponseHandler,
+          successfulResponseHandler: openaiResponseFormat.handler,
+          abortSignal,
+        });
+      },
+    });
+  }
+
+  async doGenerateTexts(prompt: string, options?: FunctionOptions) {
+    const response = await this.callAPI(prompt, {
+      ...options,
+      responseFormat: OpenAITextResponseFormat.json,
+    });
+
+    return {
+      response,
+      textGenerationResults: response.choices.map((choice) => {
+        return {
+          finishReason: this.translateFinishReason(choice.finish_reason),
+          text: choice.text,
+        };
+      }),
+      usage: {
+        promptTokens: response.usage.prompt_tokens,
+        completionTokens: response.usage.completion_tokens,
+        totalTokens: response.usage.total_tokens,
+      },
+    };
+  }
+
+  private translateFinishReason(
+    finishReason: string | null | undefined
+  ): TextGenerationFinishReason {
+    switch (finishReason) {
+      case "stop":
+        return "stop";
+      case "length":
+        return "length";
+      case "content_filter":
+        return "content-filter";
+      default:
+        return "unknown";
+    }
+  }
+
+  doStreamText(prompt: string, options?: FunctionOptions) {
+    return this.callAPI(prompt, {
+      ...options,
+      responseFormat: OpenAITextResponseFormat.deltaIterable,
+    });
+  }
+
+  extractTextDelta(delta: unknown) {
+    const chunk = delta as OpenAICompletionStreamChunk;
+
+    const firstChoice = chunk.choices[0];
+
+    if (firstChoice.index > 0) {
+      return undefined;
+    }
+
+    return chunk.choices[0].text;
+  }
+}
+
+const OpenAICompletionResponseSchema = z.object({
+  id: z.string(),
+  choices: z.array(
+    z.object({
+      finish_reason: z
+        .enum(["stop", "length", "content_filter"])
+        .optional()
+        .nullable(),
+      index: z.number(),
+      logprobs: z.nullable(z.any()),
+      text: z.string(),
+    })
+  ),
+  created: z.number(),
+  model: z.string(),
+  system_fingerprint: z.string().optional(),
+  object: z.literal("text_completion"),
+  usage: z.object({
+    prompt_tokens: z.number(),
+    completion_tokens: z.number(),
+    total_tokens: z.number(),
+  }),
+});
+
+export type OpenAICompletionResponse = z.infer<
+  typeof OpenAICompletionResponseSchema
+>;
+
+const openaiCompletionStreamChunkSchema = zodSchema(
+  z.object({
+    choices: z.array(
+      z.object({
+        text: z.string(),
+        finish_reason: z
+          .enum(["stop", "length", "content_filter"])
+          .optional()
+          .nullable(),
+        index: z.number(),
+      })
+    ),
+    created: z.number(),
+    id: z.string(),
+    model: z.string(),
+    system_fingerprint: z.string().optional(),
+    object: z.literal("text_completion"),
+  })
+);
+
+type OpenAICompletionStreamChunk =
+  (typeof openaiCompletionStreamChunkSchema)["_type"];
+
+export type OpenAITextResponseFormatType<T> = {
+  stream: boolean;
+  handler: ResponseHandler<T>;
+};
+
+export const OpenAITextResponseFormat = {
+  /**
+   * Returns the response as a JSON object.
+   */
+  json: {
+    stream: false,
+    handler: createJsonResponseHandler(OpenAICompletionResponseSchema),
+  },
+
+  /**
+   * Returns an async iterable over the full deltas (all choices, including full current state at time of event)
+   * of the response stream.
+   */
+  deltaIterable: {
+    stream: true,
+    handler: createEventSourceResponseHandler(
+      openaiCompletionStreamChunkSchema
+    ),
+  },
+};
diff --git a/packages/modelfusion/src/model-provider/openai/OpenAICompletionModel.ts b/packages/modelfusion/src/model-provider/openai/OpenAICompletionModel.ts
index d3bd32f45..bae99b1c2 100644
--- a/packages/modelfusion/src/model-provider/openai/OpenAICompletionModel.ts
+++ b/packages/modelfusion/src/model-provider/openai/OpenAICompletionModel.ts
@@ -1,30 +1,19 @@
-import { z } from "zod";
-import { FunctionOptions } from "../../core/FunctionOptions.js";
-import { ApiConfiguration } from "../../core/api/ApiConfiguration.js";
-import { callWithRetryAndThrottle } from "../../core/api/callWithRetryAndThrottle.js";
-import {
-  ResponseHandler,
-  createJsonResponseHandler,
-  postJsonToApi,
-} from "../../core/api/postToApi.js";
-import { zodSchema } from "../../core/schema/ZodSchema.js";
-import { AbstractModel } from "../../model-function/AbstractModel.js";
 import { PromptTemplateTextStreamingModel } from "../../model-function/generate-text/PromptTemplateTextStreamingModel.js";
 import {
-  TextGenerationModelSettings,
   TextStreamingModel,
   textGenerationModelProperties,
 } from "../../model-function/generate-text/TextGenerationModel.js";
 import { TextGenerationPromptTemplate } from "../../model-function/generate-text/TextGenerationPromptTemplate.js";
-import { TextGenerationFinishReason } from "../../model-function/generate-text/TextGenerationResult.js";
 import {
   chat,
   instruction,
 } from "../../model-function/generate-text/prompt-template/TextPromptTemplate.js";
 import { countTokens } from "../../model-function/tokenize-text/countTokens.js";
-import { createEventSourceResponseHandler } from "../../util/streaming/createEventSourceResponseHandler.js";
-import { OpenAIApiConfiguration } from "./OpenAIApiConfiguration.js";
-import { failedOpenAICallResponseHandler } from "./OpenAIError.js";
+import {
+  AbstractOpenAICompletionModel,
+  AbstractOpenAICompletionModelSettings,
+  OpenAICompletionResponse,
+} from "./AbstractOpenAICompletionModel.js";
 import { TikTokenTokenizer } from "./TikTokenTokenizer.js";
 
 /**
@@ -186,27 +175,9 @@ export const calculateOpenAICompletionCostInMillicents = ({
   );
 };
 
-export interface OpenAICompletionCallSettings {
-  api?: ApiConfiguration;
-
-  model: OpenAICompletionModelType;
-
-  suffix?: string;
-  temperature?: number;
-  topP?: number;
-  logprobs?: number;
-  echo?: boolean;
-  presencePenalty?: number;
-  frequencyPenalty?: number;
-  bestOf?: number;
-  logitBias?: Record<number, number>;
-  seed?: number | null;
-}
-
 export interface OpenAICompletionModelSettings
-  extends TextGenerationModelSettings,
-    Omit<OpenAICompletionCallSettings, "api"> {
-  isUserIdForwardingEnabled?: boolean;
+  extends AbstractOpenAICompletionModelSettings {
+  model: OpenAICompletionModelType;
 }
 
 /**
@@ -228,11 +199,11 @@ export interface OpenAICompletionModelSettings
  *   );
  */
 export class OpenAICompletionModel
-  extends AbstractModel<OpenAICompletionModelSettings>
+  extends AbstractOpenAICompletionModel<OpenAICompletionModelSettings>
   implements TextStreamingModel<string, OpenAICompletionModelSettings>
 {
   constructor(settings: OpenAICompletionModelSettings) {
-    super({ settings });
+    super(settings);
 
     const modelInformation = getOpenAICompletionModelInformation(
       this.settings.model
@@ -256,61 +227,6 @@ export class OpenAICompletionModel
     return countTokens(this.tokenizer, input);
   }
 
-  async callAPI<RESULT>(
-    prompt: string,
-    options: {
-      responseFormat: OpenAITextResponseFormatType<RESULT>;
-    } & FunctionOptions
-  ): Promise<RESULT> {
-    const api = this.settings.api ?? new OpenAIApiConfiguration();
-    const user = this.settings.isUserIdForwardingEnabled
-      ? options.run?.userId
-      : undefined;
-    const abortSignal = options.run?.abortSignal;
-    const openaiResponseFormat = options.responseFormat;
-
-    // empty arrays are not allowed for stop:
-    const stopSequences =
-      this.settings.stopSequences != null &&
-      Array.isArray(this.settings.stopSequences) &&
-      this.settings.stopSequences.length === 0
-        ? undefined
-        : this.settings.stopSequences;
-
-    return callWithRetryAndThrottle({
-      retry: api.retry,
-      throttle: api.throttle,
-      call: async () => {
-        return postJsonToApi({
-          url: api.assembleUrl("/completions"),
-          headers: api.headers,
-          body: {
-            stream: openaiResponseFormat.stream,
-            model: this.settings.model,
-            prompt,
-            suffix: this.settings.suffix,
-            max_tokens: this.settings.maxGenerationTokens,
-            temperature: this.settings.temperature,
-            top_p: this.settings.topP,
-            n: this.settings.numberOfGenerations,
-            logprobs: this.settings.logprobs,
-            echo: this.settings.echo,
-            stop: stopSequences,
-            seed: this.settings.seed,
-            presence_penalty: this.settings.presencePenalty,
-            frequency_penalty: this.settings.frequencyPenalty,
-            best_of: this.settings.bestOf,
-            logit_bias: this.settings.logitBias,
-            user,
-          },
-          failedResponseHandler: failedOpenAICallResponseHandler,
-          successfulResponseHandler: openaiResponseFormat.handler,
-          abortSignal,
-        });
-      },
-    });
-  }
-
   get settingsForEvent(): Partial<OpenAICompletionModelSettings> {
     const eventSettingProperties: Array<string> = [
       ...textGenerationModelProperties,
@@ -334,62 +250,6 @@ export class OpenAICompletionModel
     );
   }
 
-  async doGenerateTexts(prompt: string, options?: FunctionOptions) {
-    const response = await this.callAPI(prompt, {
-      ...options,
-      responseFormat: OpenAITextResponseFormat.json,
-    });
-
-    return {
-      response,
-      textGenerationResults: response.choices.map((choice) => {
-        return {
-          finishReason: this.translateFinishReason(choice.finish_reason),
-          text: choice.text,
-        };
-      }),
-      usage: {
-        promptTokens: response.usage.prompt_tokens,
-        completionTokens: response.usage.completion_tokens,
-        totalTokens: response.usage.total_tokens,
-      },
-    };
-  }
-
-  private translateFinishReason(
-    finishReason: string | null | undefined
-  ): TextGenerationFinishReason {
-    switch (finishReason) {
-      case "stop":
-        return "stop";
-      case "length":
-        return "length";
-      case "content_filter":
-        return "content-filter";
-      default:
-        return "unknown";
-    }
-  }
-
-  doStreamText(prompt: string, options?: FunctionOptions) {
-    return this.callAPI(prompt, {
-      ...options,
-      responseFormat: OpenAITextResponseFormat.deltaIterable,
-    });
-  }
-
-  extractTextDelta(delta: unknown) {
-    const chunk = delta as OpenAICompletionStreamChunk;
-
-    const firstChoice = chunk.choices[0];
-
-    if (firstChoice.index > 0) {
-      return undefined;
-    }
-
-    return chunk.choices[0].text;
-  }
-
   /**
    * Returns this model with an instruction prompt template.
   */
@@ -429,80 +289,3 @@ export class OpenAICompletionModel
     ) as this;
   }
 }
-
-const OpenAICompletionResponseSchema = z.object({
-  id: z.string(),
-  choices: z.array(
-    z.object({
-      finish_reason: z
-        .enum(["stop", "length", "content_filter"])
-        .optional()
-        .nullable(),
-      index: z.number(),
-      logprobs: z.nullable(z.any()),
-      text: z.string(),
-    })
-  ),
-  created: z.number(),
-  model: z.string(),
-  system_fingerprint: z.string().optional(),
-  object: z.literal("text_completion"),
-  usage: z.object({
-    prompt_tokens: z.number(),
-    completion_tokens: z.number(),
-    total_tokens: z.number(),
-  }),
-});
-
-export type OpenAICompletionResponse = z.infer<
-  typeof OpenAICompletionResponseSchema
->;
-
-const openaiCompletionStreamChunkSchema = zodSchema(
-  z.object({
-    choices: z.array(
-      z.object({
-        text: z.string(),
-        finish_reason: z
-          .enum(["stop", "length", "content_filter"])
-          .optional()
-          .nullable(),
-        index: z.number(),
-      })
-    ),
-    created: z.number(),
-    id: z.string(),
-    model: z.string(),
-    system_fingerprint: z.string().optional(),
-    object: z.literal("text_completion"),
-  })
-);
-
-type OpenAICompletionStreamChunk =
-  (typeof openaiCompletionStreamChunkSchema)["_type"];
-
-export type OpenAITextResponseFormatType<T> = {
-  stream: boolean;
-  handler: ResponseHandler<T>;
-};
-
-export const OpenAITextResponseFormat = {
-  /**
-   * Returns the response as a JSON object.
-   */
-  json: {
-    stream: false,
-    handler: createJsonResponseHandler(OpenAICompletionResponseSchema),
-  },
-
-  /**
-   * Returns an async iterable over the full deltas (all choices, including full current state at time of event)
-   * of the response stream.
-   */
-  deltaIterable: {
-    stream: true,
-    handler: createEventSourceResponseHandler(
-      openaiCompletionStreamChunkSchema
-    ),
-  },
-};
diff --git a/packages/modelfusion/src/model-provider/openai/index.ts b/packages/modelfusion/src/model-provider/openai/index.ts
index f004a18f5..a7406fc4e 100644
--- a/packages/modelfusion/src/model-provider/openai/index.ts
+++ b/packages/modelfusion/src/model-provider/openai/index.ts
@@ -1,4 +1,5 @@
 export * from "./AbstractOpenAIChatModel.js";
+export * from "./AbstractOpenAICompletionModel.js";
 export * from "./AzureOpenAIApiConfiguration.js";
 export * from "./OpenAIApiConfiguration.js";
 export * from "./OpenAIChatMessage.js";
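
Note (not part of the patch): the extracted AbstractOpenAICompletionModel is the hook for
reusing the completions logic with other OpenAI-compatible backends. Below is a minimal
sketch of such a subclass. The provider id "examplecompatible" and the settings interface
are hypothetical; settingsForEvent and withSettings mirror the patterns visible in this
patch, and the sketch assumes AbstractModel also requires provider and modelName (as in
OpenAICompletionModel).

import {
  AbstractOpenAICompletionModel,
  AbstractOpenAICompletionModelSettings,
} from "./AbstractOpenAICompletionModel.js";

// Hypothetical settings type; "model" and "api" already come from the abstract settings.
export interface ExampleCompletionModelSettings
  extends AbstractOpenAICompletionModelSettings {}

export class ExampleCompletionModel extends AbstractOpenAICompletionModel<ExampleCompletionModelSettings> {
  // Hypothetical provider id, used only for illustration.
  readonly provider = "examplecompatible";

  get modelName() {
    return this.settings.model;
  }

  get settingsForEvent(): Partial<ExampleCompletionModelSettings> {
    return { model: this.settings.model };
  }

  withSettings(additionalSettings: Partial<ExampleCompletionModelSettings>) {
    return new ExampleCompletionModel(
      Object.assign({}, this.settings, additionalSettings)
    ) as this;
  }
}

Pointing such a subclass at a different backend is then a matter of passing an
ApiConfiguration whose assembleUrl("/completions") resolves to that provider's endpoint.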