Skip to content

Commit

Permalink
Refactor prompt validation. (#1247)
Browse files Browse the repository at this point in the history
  • Loading branch information
lgrammel authored Mar 28, 2024
1 parent a54ea77 commit fafc8d5
Show file tree
Hide file tree
Showing 7 changed files with 180 additions and 134 deletions.
48 changes: 29 additions & 19 deletions packages/core/core/generate-object/generate-object.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import {
import { TokenUsage, calculateTokenUsage } from '../generate-text/token-usage';
import { CallSettings } from '../prompt/call-settings';
import { convertToLanguageModelPrompt } from '../prompt/convert-to-language-model-prompt';
import { getInputFormat } from '../prompt/get-input-format';
import { getValidatedPrompt } from '../prompt/get-validated-prompt';
import { prepareCallSettings } from '../prompt/prepare-call-settings';
import { Prompt } from '../prompt/prompt';
import { retryWithExponentialBackoff } from '../util/retry-with-exponential-backoff';
Expand Down Expand Up @@ -96,19 +96,21 @@ Default and recommended: 'auto' (best mode for the model).

switch (mode) {
case 'json': {
const generateResult = await retry(() =>
model.doGenerate({
const validatedPrompt = getValidatedPrompt({
system: injectJsonSchemaIntoSystem({ system, schema: jsonSchema }),
prompt,
messages,
});

const generateResult = await retry(() => {
return model.doGenerate({
mode: { type: 'object-json' },
...prepareCallSettings(settings),
inputFormat: getInputFormat({ prompt, messages }),
prompt: convertToLanguageModelPrompt({
system: injectJsonSchemaIntoSystem({ system, schema: jsonSchema }),
prompt,
messages,
}),
inputFormat: validatedPrompt.type,
prompt: convertToLanguageModelPrompt(validatedPrompt),
abortSignal,
}),
);
});
});

if (generateResult.text === undefined) {
throw new NoTextGeneratedError();
Expand All @@ -123,16 +125,18 @@ Default and recommended: 'auto' (best mode for the model).
}

case 'grammar': {
const validatedPrompt = getValidatedPrompt({
system: injectJsonSchemaIntoSystem({ system, schema: jsonSchema }),
prompt,
messages,
});

const generateResult = await retry(() =>
model.doGenerate({
mode: { type: 'object-grammar', schema: jsonSchema },
...settings,
inputFormat: getInputFormat({ prompt, messages }),
prompt: convertToLanguageModelPrompt({
system: injectJsonSchemaIntoSystem({ system, schema: jsonSchema }),
prompt,
messages,
}),
inputFormat: validatedPrompt.type,
prompt: convertToLanguageModelPrompt(validatedPrompt),
abortSignal,
}),
);
Expand All @@ -150,6 +154,12 @@ Default and recommended: 'auto' (best mode for the model).
}

case 'tool': {
const validatedPrompt = getValidatedPrompt({
system,
prompt,
messages,
});

const generateResult = await retry(() =>
model.doGenerate({
mode: {
Expand All @@ -162,8 +172,8 @@ Default and recommended: 'auto' (best mode for the model).
},
},
...settings,
inputFormat: getInputFormat({ prompt, messages }),
prompt: convertToLanguageModelPrompt({ system, prompt, messages }),
inputFormat: validatedPrompt.type,
prompt: convertToLanguageModelPrompt(validatedPrompt),
abortSignal,
}),
);
Expand Down
40 changes: 25 additions & 15 deletions packages/core/core/generate-object/stream-object.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import {
} from '../../spec';
import { CallSettings } from '../prompt/call-settings';
import { convertToLanguageModelPrompt } from '../prompt/convert-to-language-model-prompt';
import { getInputFormat } from '../prompt/get-input-format';
import { getValidatedPrompt } from '../prompt/get-validated-prompt';
import { prepareCallSettings } from '../prompt/prepare-call-settings';
import { Prompt } from '../prompt/prompt';
import {
Expand Down Expand Up @@ -99,15 +99,17 @@ Default and recommended: 'auto' (best mode for the model).

switch (mode) {
case 'json': {
const validatedPrompt = getValidatedPrompt({
system: injectJsonSchemaIntoSystem({ system, schema: jsonSchema }),
prompt,
messages,
});

callOptions = {
mode: { type: 'object-json' },
...prepareCallSettings(settings),
inputFormat: getInputFormat({ prompt, messages }),
prompt: convertToLanguageModelPrompt({
system: injectJsonSchemaIntoSystem({ system, schema: jsonSchema }),
prompt,
messages,
}),
inputFormat: validatedPrompt.type,
prompt: convertToLanguageModelPrompt(validatedPrompt),
abortSignal,
};

Expand All @@ -128,15 +130,17 @@ Default and recommended: 'auto' (best mode for the model).
}

case 'grammar': {
const validatedPrompt = getValidatedPrompt({
system: injectJsonSchemaIntoSystem({ system, schema: jsonSchema }),
prompt,
messages,
});

callOptions = {
mode: { type: 'object-grammar', schema: jsonSchema },
...settings,
inputFormat: getInputFormat({ prompt, messages }),
prompt: convertToLanguageModelPrompt({
system: injectJsonSchemaIntoSystem({ system, schema: jsonSchema }),
prompt,
messages,
}),
inputFormat: validatedPrompt.type,
prompt: convertToLanguageModelPrompt(validatedPrompt),
abortSignal,
};

Expand All @@ -157,6 +161,12 @@ Default and recommended: 'auto' (best mode for the model).
}

case 'tool': {
const validatedPrompt = getValidatedPrompt({
system,
prompt,
messages,
});

callOptions = {
mode: {
type: 'object-tool',
Expand All @@ -168,8 +178,8 @@ Default and recommended: 'auto' (best mode for the model).
},
},
...settings,
inputFormat: getInputFormat({ prompt, messages }),
prompt: convertToLanguageModelPrompt({ system, prompt, messages }),
inputFormat: validatedPrompt.type,
prompt: convertToLanguageModelPrompt(validatedPrompt),
abortSignal,
};

Expand Down
19 changes: 8 additions & 11 deletions packages/core/core/generate-text/generate-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import {
} from '../../spec';
import { CallSettings } from '../prompt/call-settings';
import { convertToLanguageModelPrompt } from '../prompt/convert-to-language-model-prompt';
import { getInputFormat } from '../prompt/get-input-format';
import { getValidatedPrompt } from '../prompt/get-validated-prompt';
import { prepareCallSettings } from '../prompt/prepare-call-settings';
import { Prompt } from '../prompt/prompt';
import { ExperimentalTool } from '../tool/tool';
Expand Down Expand Up @@ -75,8 +75,9 @@ The tools that the model can call. The model needs to support calling tools.
tools?: TOOLS;
}): Promise<GenerateTextResult<TOOLS>> {
const retry = retryWithExponentialBackoff({ maxRetries });
const modelResponse = await retry(() =>
model.doGenerate({
const validatedPrompt = getValidatedPrompt({ system, prompt, messages });
const modelResponse = await retry(() => {
return model.doGenerate({
mode: {
type: 'regular',
tools:
Expand All @@ -90,15 +91,11 @@ The tools that the model can call. The model needs to support calling tools.
})),
},
...prepareCallSettings(settings),
inputFormat: getInputFormat({ prompt, messages }),
prompt: convertToLanguageModelPrompt({
system,
prompt,
messages,
}),
inputFormat: validatedPrompt.type,
prompt: convertToLanguageModelPrompt(validatedPrompt),
abortSignal,
}),
);
});
});

// parse tool calls:
const toolCalls: ToToolCallArray<TOOLS> = [];
Expand Down
11 changes: 4 additions & 7 deletions packages/core/core/generate-text/stream-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import {
} from '../../streams';
import { CallSettings } from '../prompt/call-settings';
import { convertToLanguageModelPrompt } from '../prompt/convert-to-language-model-prompt';
import { getInputFormat } from '../prompt/get-input-format';
import { getValidatedPrompt } from '../prompt/get-validated-prompt';
import { prepareCallSettings } from '../prompt/prepare-call-settings';
import { Prompt } from '../prompt/prompt';
import { ExperimentalTool } from '../tool';
Expand Down Expand Up @@ -82,6 +82,7 @@ The tools that the model can call. The model needs to support calling tools.
tools?: TOOLS;
}): Promise<StreamTextResult<TOOLS>> {
const retry = retryWithExponentialBackoff({ maxRetries });
const validatedPrompt = getValidatedPrompt({ system, prompt, messages });
const { stream, warnings } = await retry(() =>
model.doStream({
mode: {
Expand All @@ -97,12 +98,8 @@ The tools that the model can call. The model needs to support calling tools.
})),
},
...prepareCallSettings(settings),
inputFormat: getInputFormat({ prompt, messages }),
prompt: convertToLanguageModelPrompt({
system,
prompt,
messages,
}),
inputFormat: validatedPrompt.type,
prompt: convertToLanguageModelPrompt(validatedPrompt),
abortSignal,
}),
);
Expand Down
133 changes: 67 additions & 66 deletions packages/core/core/prompt/convert-to-language-model-prompt.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,88 +5,89 @@ import {
LanguageModelV1TextPart,
} from '../../spec';
import { convertDataContentToUint8Array } from './data-content';
import { Prompt } from './prompt';

export function convertToLanguageModelPrompt({
system,
prompt,
messages,
}: Prompt): LanguageModelV1Prompt {
if (prompt == null && messages == null) {
throw new Error('prompt or messages must be defined');
}

if (prompt != null && messages != null) {
throw new Error('prompt and messages cannot be defined at the same time');
}
import { ValidatedPrompt } from './get-validated-prompt';

export function convertToLanguageModelPrompt(
prompt: ValidatedPrompt,
): LanguageModelV1Prompt {
const languageModelMessages: LanguageModelV1Prompt = [];

if (system != null) {
languageModelMessages.push({ role: 'system', content: system });
if (prompt.system != null) {
languageModelMessages.push({ role: 'system', content: prompt.system });
}

if (typeof prompt === 'string') {
languageModelMessages.push({
role: 'user',
content: [{ type: 'text', text: prompt }],
});
} else {
messages = messages!; // it's not null because of the check above
switch (prompt.type) {
case 'prompt': {
languageModelMessages.push({
role: 'user',
content: [{ type: 'text', text: prompt.prompt }],
});
break;
}

case 'messages': {
languageModelMessages.push(
...prompt.messages.map((message): LanguageModelV1Message => {
switch (message.role) {
case 'user': {
if (typeof message.content === 'string') {
return {
role: 'user',
content: [{ type: 'text', text: message.content }],
};
}

languageModelMessages.push(
...messages.map((message): LanguageModelV1Message => {
switch (message.role) {
case 'user': {
if (typeof message.content === 'string') {
return {
role: 'user',
content: [{ type: 'text', text: message.content }],
};
}
content: message.content.map(
(
part,
): LanguageModelV1TextPart | LanguageModelV1ImagePart => {
switch (part.type) {
case 'text': {
return part;
}

return {
role: 'user',
content: message.content.map(
(part): LanguageModelV1TextPart | LanguageModelV1ImagePart => {
switch (part.type) {
case 'text': {
return part;
case 'image': {
return {
type: 'image',
image:
part.image instanceof URL
? part.image
: convertDataContentToUint8Array(part.image),
mimeType: part.mimeType,
};
}
}
},
),
};
}

case 'image': {
return {
type: 'image',
image:
part.image instanceof URL
? part.image
: convertDataContentToUint8Array(part.image),
mimeType: part.mimeType,
};
}
}
},
),
};
}
case 'assistant': {
if (typeof message.content === 'string') {
return {
role: 'assistant',
content: [{ type: 'text', text: message.content }],
};
}

case 'assistant': {
if (typeof message.content === 'string') {
return {
role: 'assistant',
content: [{ type: 'text', text: message.content }],
};
return { role: 'assistant', content: message.content };
}

return { role: 'assistant', content: message.content };
case 'tool': {
return message;
}
}
}),
);
break;
}

case 'tool': {
return message;
}
}
}),
);
default: {
const _exhaustiveCheck: never = prompt;
throw new Error(`Unsupported prompt type: ${_exhaustiveCheck}`);
}
}

return languageModelMessages;
Expand Down
Loading

0 comments on commit fafc8d5

Please sign in to comment.