Generalize text generation numberOfGenerations.
lgrammel committed Dec 16, 2023
1 parent e89b6b6 commit 26684f9
Showing 21 changed files with 346 additions and 441 deletions.
1 change: 1 addition & 0 deletions .vscode/settings.json
@@ -18,6 +18,7 @@
"llamacpp",
"Lmnt",
"logit",
"logprobs",
"Millicents",
"mirostat",
"modelfusion",
21 changes: 21 additions & 0 deletions CHANGELOG.md
@@ -1,5 +1,26 @@
# Changelog

## Unreleased

- You can now specify `numberOfGenerations` on text generation models and access multiple generations by using the `fullResponse: true` option. Example:

```ts
// generate 2 texts:
const { texts } = await generateText(
openai.CompletionTextGenerator({
model: "gpt-3.5-turbo-instruct",
numberOfGenerations: 2,
maxCompletionTokens: 1000,
}),
"Write a short story about a robot learning to love:\n\n",
{ fullResponse: true }
);
```

### Added

- **breaking change**: Text generation models now use a generalized `numberOfGenerations` parameter (instead of model-specific parameters) to specify the number of generations.
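
For example, where a call previously used OpenAI's provider-specific `n` parameter, it now uses the generalized `numberOfGenerations`:

```ts
// before (provider-specific):
openai.ChatTextGenerator({ model: "gpt-3.5-turbo", n: 2 });

// after (generalized):
openai.ChatTextGenerator({ model: "gpt-3.5-turbo", numberOfGenerations: 2 });
```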

## v0.98.0 - 2023-12-16

### Changed
20 changes: 20 additions & 0 deletions docs/guide/function/generate-text.md
@@ -18,9 +18,14 @@ You can use [prompt templates](#prompt-format) to change the prompt template of
The different [TextGenerationModel](/api/interfaces/TextGenerationModel) implementations (see [available providers](#available-providers)) share some common settings:

- **maxCompletionTokens**: The maximum number of tokens to generate, or undefined to generate an unlimited number of tokens.
- **numberOfGenerations**: The number of completions to generate.
- **stopSequences**: An array of text sequences that will stop the text generation when they are generated. The sequences are not included in the generated text. The default is an empty array.
- **trimWhitespace**: When true (default), the leading and trailing white space and line terminator characters are removed from the generated text. Only applies to `generateText`.

:::note
Not all models support all common settings. E.g., the `numberOfGenerations` setting is not supported by some local models.
:::

In addition to these common settings, each model exposes its own settings.
The settings can be set in the constructor of the model, or in the `withSettings` method.
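
For example (a minimal sketch; `withSettings` returns a new model instance with the additional settings applied):

```ts
import { openai } from "modelfusion";

// set in the constructor:
const model = openai.CompletionTextGenerator({
  model: "gpt-3.5-turbo-instruct",
  numberOfGenerations: 2,
});

// or derive a reconfigured copy with withSettings:
const singleGeneration = model.withSettings({ numberOfGenerations: 1 });
```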

@@ -51,6 +56,21 @@ const text = await generateText(openai.ChatTextGenerator(/* ... */), [
]);
```

#### Example: Generate multiple completions

```ts
import { generateText, openai } from "modelfusion";

const { texts } = await generateText(
openai.CompletionTextGenerator({
model: "gpt-3.5-turbo-instruct",
numberOfGenerations: 2,
}),
"Write a short story about a robot learning to love:",
{ fullResponse: true }
);
```
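
The returned `texts` array contains one entry per generation, while `text` holds the first generation. You can iterate over `texts` directly:

```ts
for (const text of texts) {
  console.log(text);
}
```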

#### Example: OpenAI chat model with multi-modal input

Multi-modal vision models such as GPT-4 Vision can process images as part of the prompt.
2 changes: 1 addition & 1 deletion docs/guide/function/index.md
@@ -37,7 +37,7 @@ Model functions return rich results that include the original response and metadata

```ts
// access the full response (needs to be typed) and the metadata:
const { text, response, metadata } = await generateText(
const { text, texts, response, metadata } = await generateText(
openai.CompletionTextGenerator({
model: "gpt-3.5-turbo-instruct",
maxCompletionTokens: 1000,
@@ -5,11 +5,11 @@ dotenv.config();

async function main() {
// access the full response and the metadata:
const { text, response, metadata } = await generateText(
const { text, texts, response, metadata } = await generateText(
openai.CompletionTextGenerator({
model: "gpt-3.5-turbo-instruct",
numberOfGenerations: 2,
maxCompletionTokens: 1000,
n: 2, // generate 2 completions
}),
"Write a short story about a robot learning to love:\n\n",
{ fullResponse: true }
@@ -20,6 +20,7 @@ async function main() {
// cast to the response type:
for (const choice of (response as OpenAICompletionResponse).choices) {
console.log(choice.text);
console.log(choice.finish_reason);
console.log();
console.log();
}
@@ -0,0 +1,25 @@
import dotenv from "dotenv";
import { generateText, openai } from "modelfusion";

dotenv.config();

async function main() {
const { texts } = await generateText(
openai.CompletionTextGenerator({
model: "gpt-3.5-turbo-instruct",
numberOfGenerations: 2,
maxCompletionTokens: 1000,
}),
"Write a short story about a robot learning to love:\n\n",
{ fullResponse: true }
);

// multiple generations:
for (const text of texts) {
console.log(text);
console.log();
console.log();
}
}

main().catch(console.error);
@@ -8,7 +8,7 @@ async function main() {
openai.ChatTextGenerator({
model: "gpt-3.5-turbo",
maxCompletionTokens: 1000,
n: 2,
numberOfGenerations: 2,
}),
[
OpenAIChatMessage.system("You are a story writer. Write a story about:"),
@@ -73,9 +73,9 @@ export class PromptTemplateTextGenerationModel<
: (prompt: PROMPT) => PromiseLike<number>;
}

doGenerateText(prompt: PROMPT, options?: FunctionOptions) {
doGenerateTexts(prompt: PROMPT, options?: FunctionOptions) {
const mappedPrompt = this.promptTemplate.format(prompt);
return this.model.doGenerateText(mappedPrompt, options);
return this.model.doGenerateTexts(mappedPrompt, options);
}

get settingsForEvent(): Partial<SETTINGS> {
35 changes: 31 additions & 4 deletions src/model-function/generate-text/TextGenerationModel.ts
@@ -6,20 +6,47 @@ import { TextGenerationPromptTemplate } from "./TextGenerationPromptTemplate.js"

export interface TextGenerationModelSettings extends ModelSettings {
/**
* Maximum number of tokens to generate.
* Specifies the maximum number of tokens (words, punctuation, parts of words) that the model can generate in a single response.
* It helps to control the length of the output.
*
* Does nothing if the model does not support this setting.
*
* Example: `maxCompletionTokens: 1000`
*/
maxCompletionTokens?: number | undefined;

/**
* Stop sequences to use. Stop sequences are not included in the generated text.
* Stop sequences to use.
* Stop sequences are an array of strings that the model will recognize as end-of-text indicators.
* The model stops generating more content when it encounters any of these strings.
* This is particularly useful in scripted or formatted text generation, where a specific end point is required.
* Stop sequences are not included in the generated text.
*
* Does nothing if the model does not support this setting.
*
* Example: `stopSequences: ['\n', 'END']`
*/
stopSequences?: string[] | undefined;

/**
* Number of texts to generate.
*
* Specifies the number of responses or completions the model should generate for a given prompt.
* This is useful when you need multiple different outputs or ideas for a single prompt.
* The model will generate 'n' distinct responses, each based on the same initial prompt.
* When streaming, all generated responses are streamed back in real time.
*
* Does nothing if the model does not support this setting.
*
* Example: `numberOfGenerations: 3` // The model will produce 3 different responses.
*/
numberOfGenerations?: number;

/**
* When true, the leading and trailing white space and line terminator characters
* are removed from the generated text.
*
* Default: true.
*/
trimWhitespace?: boolean;
}
@@ -49,12 +76,12 @@ export interface TextGenerationModel<
| ((prompt: PROMPT) => PromiseLike<number>)
| undefined;

doGenerateText(
doGenerateTexts(
prompt: PROMPT,
options?: FunctionOptions
): PromiseLike<{
response: unknown;
text: string;
texts: string[];
usage?: {
promptTokens: number;
completionTokens: number;
31 changes: 25 additions & 6 deletions src/model-function/generate-text/generateText.ts
@@ -36,36 +36,55 @@ export async function generateText<PROMPT>(
model: TextGenerationModel<PROMPT, TextGenerationModelSettings>,
prompt: PROMPT,
options: FunctionOptions & { fullResponse: true }
): Promise<{ text: string; response: unknown; metadata: ModelCallMetadata }>;
): Promise<{
text: string;
texts: string[];
response: unknown;
metadata: ModelCallMetadata;
}>;
export async function generateText<PROMPT>(
model: TextGenerationModel<PROMPT, TextGenerationModelSettings>,
prompt: PROMPT,
options?: FunctionOptions & { fullResponse?: boolean }
): Promise<
string | { text: string; response: unknown; metadata: ModelCallMetadata }
| string
| {
text: string;
texts: string[];
response: unknown;
metadata: ModelCallMetadata;
}
> {
const fullResponse = await executeStandardCall({
functionType: "generate-text",
input: prompt,
model,
options,
generateResponse: async (options) => {
const result = await model.doGenerateText(prompt, options);
const result = await model.doGenerateTexts(prompt, options);
const shouldTrimWhitespace = model.settings.trimWhitespace ?? true;

const texts = shouldTrimWhitespace
? result.texts.map((text) => text.trim())
: result.texts;

return {
response: result.response,
extractedValue: shouldTrimWhitespace ? result.text.trim() : result.text,
extractedValue: texts,
usage: result.usage,
};
},
});

const texts = fullResponse.value;
const text = texts[0];

return options?.fullResponse
? {
text: fullResponse.value,
text,
texts,
response: fullResponse.response,
metadata: fullResponse.metadata,
}
: fullResponse.value;
: text;
}
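
A quick sketch of the two call shapes these overloads produce (assuming `model` is any configured `TextGenerationModel`):

```ts
// default: resolves to the first generated text as a plain string
const story = await generateText(model, "Write a short story:");

// with fullResponse: true, resolves to all generations plus raw response and metadata
const { text, texts, response, metadata } = await generateText(
  model,
  "Write a short story:",
  { fullResponse: true }
);
```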
85 changes: 27 additions & 58 deletions src/model-provider/anthropic/AnthropicTextGenerationModel.ts
@@ -87,20 +87,34 @@ export class AnthropicTextGenerationModel
responseFormat: AnthropicTextGenerationResponseFormatType<RESPONSE>;
} & FunctionOptions
): Promise<RESPONSE> {
const api = this.settings.api ?? new AnthropicApiConfiguration();
const responseFormat = options.responseFormat;
const abortSignal = options.run?.abortSignal;
const userId = this.settings.userId;

return callWithRetryAndThrottle({
retry: this.settings.api?.retry,
throttle: this.settings.api?.throttle,
call: async () =>
callAnthropicTextGenerationAPI({
...this.settings,

stopSequences: this.settings.stopSequences,
maxTokens: this.settings.maxCompletionTokens,

abortSignal: options.run?.abortSignal,
responseFormat: options.responseFormat,
prompt,
}),
call: async () => {
return postJsonToApi({
url: api.assembleUrl(`/complete`),
headers: api.headers,
body: {
model: this.settings.model,
prompt,
stream: responseFormat.stream,
max_tokens_to_sample: this.settings.maxCompletionTokens,
temperature: this.settings.temperature,
top_k: this.settings.topK,
top_p: this.settings.topP,
stop_sequences: this.settings.stopSequences,
metadata: userId != null ? { user_id: userId } : undefined,
},
failedResponseHandler: failedAnthropicCallResponseHandler,
successfulResponseHandler: responseFormat.handler,
abortSignal,
});
},
});
}

@@ -122,15 +136,15 @@
);
}

async doGenerateText(prompt: string, options?: FunctionOptions) {
async doGenerateTexts(prompt: string, options?: FunctionOptions) {
const response = await this.callAPI(prompt, {
...options,
responseFormat: AnthropicTextGenerationResponseFormat.json,
});

return {
response,
text: response.completion,
texts: [response.completion],
};
}

@@ -200,51 +214,6 @@ export type AnthropicTextGenerationResponse = z.infer<
typeof anthropicTextGenerationResponseSchema
>;

async function callAnthropicTextGenerationAPI<RESPONSE>({
api = new AnthropicApiConfiguration(),
abortSignal,
responseFormat,
model,
prompt,
maxTokens,
stopSequences,
temperature,
topK,
topP,
userId,
}: {
api?: ApiConfiguration;
abortSignal?: AbortSignal;
responseFormat: AnthropicTextGenerationResponseFormatType<RESPONSE>;
model: AnthropicTextGenerationModelType;
prompt: string;
maxTokens?: number;
stopSequences?: string[];
temperature?: number;
topP?: number;
topK?: number;
userId?: number;
}): Promise<RESPONSE> {
return postJsonToApi({
url: api.assembleUrl(`/complete`),
headers: api.headers,
body: {
model,
prompt,
stream: responseFormat.stream,
max_tokens_to_sample: maxTokens,
temperature,
top_k: topK,
top_p: topP,
stop_sequences: stopSequences,
metadata: userId != null ? { user_id: userId } : undefined,
},
failedResponseHandler: failedAnthropicCallResponseHandler,
successfulResponseHandler: responseFormat.handler,
abortSignal,
});
}

const anthropicTextStreamingResponseSchema = new ZodSchema(
z.object({
completion: z.string(),