Extract AbstractOpenAICompletionModel.
lgrammel committed Dec 30, 2023
1 parent b81923b commit f95d185
Showing 3 changed files with 247 additions and 226 deletions.
@@ -0,0 +1,237 @@
import { z } from "zod";
import { FunctionOptions } from "../../core/FunctionOptions.js";
import { ApiConfiguration } from "../../core/api/ApiConfiguration.js";
import { callWithRetryAndThrottle } from "../../core/api/callWithRetryAndThrottle.js";
import {
  ResponseHandler,
  createJsonResponseHandler,
  postJsonToApi,
} from "../../core/api/postToApi.js";
import { zodSchema } from "../../core/schema/ZodSchema.js";
import { AbstractModel } from "../../model-function/AbstractModel.js";
import { TextGenerationModelSettings } from "../../model-function/generate-text/TextGenerationModel.js";
import { TextGenerationFinishReason } from "../../model-function/generate-text/TextGenerationResult.js";
import { createEventSourceResponseHandler } from "../../util/streaming/createEventSourceResponseHandler.js";
import { OpenAIApiConfiguration } from "./OpenAIApiConfiguration.js";
import { failedOpenAICallResponseHandler } from "./OpenAIError.js";

export interface AbstractOpenAICompletionModelSettings
  extends TextGenerationModelSettings {
  api?: ApiConfiguration;

  model: string;

  suffix?: string;
  temperature?: number;
  topP?: number;
  logprobs?: number;
  echo?: boolean;
  presencePenalty?: number;
  frequencyPenalty?: number;
  bestOf?: number;
  logitBias?: Record<number, number>;
  seed?: number | null;

  isUserIdForwardingEnabled?: boolean;
}

/**
 * Abstract completion model that calls an API that is compatible with the OpenAI completions API.
 *
 * @see https://platform.openai.com/docs/api-reference/completions/create
 */
export abstract class AbstractOpenAICompletionModel<
  SETTINGS extends AbstractOpenAICompletionModelSettings,
> extends AbstractModel<SETTINGS> {
  constructor(settings: SETTINGS) {
    super({ settings });
  }

  async callAPI<RESULT>(
    prompt: string,
    options: {
      responseFormat: OpenAITextResponseFormatType<RESULT>;
    } & FunctionOptions
  ): Promise<RESULT> {
    const api = this.settings.api ?? new OpenAIApiConfiguration();
    const user = this.settings.isUserIdForwardingEnabled
      ? options.run?.userId
      : undefined;
    const abortSignal = options.run?.abortSignal;
    const openaiResponseFormat = options.responseFormat;

    // empty arrays are not allowed for stop:
    const stopSequences =
      this.settings.stopSequences != null &&
      Array.isArray(this.settings.stopSequences) &&
      this.settings.stopSequences.length === 0
        ? undefined
        : this.settings.stopSequences;

    return callWithRetryAndThrottle({
      retry: api.retry,
      throttle: api.throttle,
      call: async () => {
        return postJsonToApi({
          url: api.assembleUrl("/completions"),
          headers: api.headers,
          body: {
            stream: openaiResponseFormat.stream,
            model: this.settings.model,
            prompt,
            suffix: this.settings.suffix,
            max_tokens: this.settings.maxGenerationTokens,
            temperature: this.settings.temperature,
            top_p: this.settings.topP,
            n: this.settings.numberOfGenerations,
            logprobs: this.settings.logprobs,
            echo: this.settings.echo,
            stop: stopSequences,
            seed: this.settings.seed,
            presence_penalty: this.settings.presencePenalty,
            frequency_penalty: this.settings.frequencyPenalty,
            best_of: this.settings.bestOf,
            logit_bias: this.settings.logitBias,
            user,
          },
          failedResponseHandler: failedOpenAICallResponseHandler,
          successfulResponseHandler: openaiResponseFormat.handler,
          abortSignal,
        });
      },
    });
  }
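
  // For illustration (not part of this commit): with settings such as
  // { model: "gpt-3.5-turbo-instruct", temperature: 0.7, maxGenerationTokens: 64 },
  // the non-streaming request body assembled above would serialize to
  // { "stream": false, "model": "gpt-3.5-turbo-instruct", "prompt": "...",
  //   "max_tokens": 64, "temperature": 0.7 }
  // (fields left undefined are dropped during JSON serialization).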

  async doGenerateTexts(prompt: string, options?: FunctionOptions) {
    const response = await this.callAPI(prompt, {
      ...options,
      responseFormat: OpenAITextResponseFormat.json,
    });

    return {
      response,
      textGenerationResults: response.choices.map((choice) => {
        return {
          finishReason: this.translateFinishReason(choice.finish_reason),
          text: choice.text,
        };
      }),
      usage: {
        promptTokens: response.usage.prompt_tokens,
        completionTokens: response.usage.completion_tokens,
        totalTokens: response.usage.total_tokens,
      },
    };
  }

  private translateFinishReason(
    finishReason: string | null | undefined
  ): TextGenerationFinishReason {
    switch (finishReason) {
      case "stop":
        return "stop";
      case "length":
        return "length";
      case "content_filter":
        return "content-filter";
      default:
        return "unknown";
    }
  }

  doStreamText(prompt: string, options?: FunctionOptions) {
    return this.callAPI(prompt, {
      ...options,
      responseFormat: OpenAITextResponseFormat.deltaIterable,
    });
  }

  extractTextDelta(delta: unknown) {
    const chunk = delta as OpenAICompletionStreamChunk;

    const firstChoice = chunk.choices[0];

    // Only chunks for the first choice (index 0) yield a text delta; also
    // guard against chunks that carry no choices at all.
    if (firstChoice == null || firstChoice.index > 0) {
      return undefined;
    }

    return firstChoice.text;
  }
}
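
// Illustrative sketch (not part of this commit): a minimal concrete subclass.
// It assumes that AbstractModel's abstract members are `provider`, `modelName`,
// `settingsForEvent`, and `withSettings`; the concrete OpenAICompletionModel in
// this repository additionally wires up tokenizers, context window sizes, and
// prompt templates.
class ExampleCompletionModel extends AbstractOpenAICompletionModel<AbstractOpenAICompletionModelSettings> {
  readonly provider = "openai";

  get modelName() {
    return this.settings.model;
  }

  get settingsForEvent() {
    return { model: this.settings.model };
  }

  withSettings(
    additionalSettings: Partial<AbstractOpenAICompletionModelSettings>
  ) {
    return new ExampleCompletionModel({
      ...this.settings,
      ...additionalSettings,
    }) as this;
  }
}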

const OpenAICompletionResponseSchema = z.object({
  id: z.string(),
  choices: z.array(
    z.object({
      finish_reason: z
        .enum(["stop", "length", "content_filter"])
        .optional()
        .nullable(),
      index: z.number(),
      logprobs: z.nullable(z.any()),
      text: z.string(),
    })
  ),
  created: z.number(),
  model: z.string(),
  system_fingerprint: z.string().optional(),
  object: z.literal("text_completion"),
  usage: z.object({
    prompt_tokens: z.number(),
    completion_tokens: z.number(),
    total_tokens: z.number(),
  }),
});

export type OpenAICompletionResponse = z.infer<
  typeof OpenAICompletionResponseSchema
>;

const openaiCompletionStreamChunkSchema = zodSchema(
  z.object({
    choices: z.array(
      z.object({
        text: z.string(),
        finish_reason: z
          .enum(["stop", "length", "content_filter"])
          .optional()
          .nullable(),
        index: z.number(),
      })
    ),
    created: z.number(),
    id: z.string(),
    model: z.string(),
    system_fingerprint: z.string().optional(),
    object: z.literal("text_completion"),
  })
);

type OpenAICompletionStreamChunk =
  (typeof openaiCompletionStreamChunkSchema)["_type"];

export type OpenAITextResponseFormatType<T> = {
  stream: boolean;
  handler: ResponseHandler<T>;
};

export const OpenAITextResponseFormat = {
  /**
   * Returns the response as a JSON object.
   */
  json: {
    stream: false,
    handler: createJsonResponseHandler(OpenAICompletionResponseSchema),
  },

  /**
   * Returns an async iterable over the full deltas (all choices, including
   * the full current state at the time of the event) of the response stream.
   */
  deltaIterable: {
    stream: true,
    handler: createEventSourceResponseHandler(
      openaiCompletionStreamChunkSchema
    ),
  },
};
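
// Hedged usage sketch (not part of this commit), building on the
// ExampleCompletionModel sketched above. `doGenerateTexts` uses the `json`
// response format; `doStreamText` returns the event-source delta iterable
// that the library's streaming helpers consume via `extractTextDelta`.
async function exampleUsage() {
  const model = new ExampleCompletionModel({
    model: "gpt-3.5-turbo-instruct", // example model id, an assumption
    maxGenerationTokens: 64,
  });

  const { textGenerationResults, usage } = await model.doGenerateTexts(
    "Write a tagline for a refactoring tool:"
  );

  console.log(textGenerationResults[0].text);
  console.log(usage.totalTokens);
}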