review comments + pulling changes from another PR
farzadab committed Jul 19, 2023
1 parent ec10b79 commit 3c4d652
Showing 5 changed files with 58 additions and 32 deletions.
67 changes: 42 additions & 25 deletions packages/ai-jsx/src/batteries/constrained-output.tsx
@@ -13,10 +13,11 @@ import {
ModelPropsWithChildren,
} from '../core/completion.js';
import yaml from 'js-yaml';
import { AIJSXError, ErrorCode } from '../core/errors.js';
import { AIJSXError, ErrorCode, ErrorBlame } from '../core/errors.js';
import { Jsonifiable } from 'type-fest';
import z from 'zod';
import { zodToJsonSchema } from 'zod-to-json-schema';
import untruncateJson from 'untruncate-json';
import { patchedUntruncateJson } from '../lib/util.js';

export type ObjectCompletion = ModelPropsWithChildren & {
/** Validators are used to ensure that the final object looks as expected. */
@@ -27,6 +28,12 @@ export type ObjectCompletion = ModelPropsWithChildren & {
*
* @note To match OpenAI function definition specs, the schema must be a Zod object.
* Arrays and other types should be wrapped in a top-level object in order to be used.
*
* For example, to describe a list of strings, the following is not accepted:
* `const schema: z.Schema = z.array(z.string())`
*
* Instead, you can wrap it in an object like so:
* `const schema: z.ZodObject = z.object({ arr: z.array(z.string()) })`
*/
schema?: z.ZodObject<any>;
/** Any output example to be shown to the model. */
@@ -88,22 +95,19 @@ export async function* JsonChatCompletion(
{ schema, ...props }: Omit<TypedObjectCompletionWithRetry, 'typeName' | 'parser' | 'partialResultCleaner'>,
{ render }: AI.ComponentContext
) {
if (schema) {
try {
return yield* render(<JsonChatCompletionFunctionCall schema={schema} {...props} />);
} catch (e: any) {
if (e.code !== ErrorCode.ChatModelDoesNotSupportFunctions) {
throw e;
}
try {
return yield* render(<JsonChatCompletionFunctionCall schema={schema ?? z.object({}).nonstrict()} {...props} />);
} catch (e: any) {
if (e.code !== ErrorCode.ChatModelDoesNotSupportFunctions) {
throw e;
}
}
return yield* render(
<ObjectCompletionWithRetry
{...props}
typeName="JSON"
parser={JSON.parse}
// TODO: can we remove .default?
partialResultCleaner={untruncateJson.default}
partialResultCleaner={patchedUntruncateJson}
/>
);
}
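
For reviewers trying the change locally, here is a minimal usage sketch of the updated `JsonChatCompletion`. The JSX pragma, import paths, prompt text, and render call are assumptions based on the package layout, not part of this diff; only the `schema` prop and the wrap-arrays-in-an-object rule come from the code above.

```tsx
/** @jsxImportSource ai-jsx */
import * as AI from 'ai-jsx';
import { UserMessage } from 'ai-jsx/core/completion';
import { JsonChatCompletion } from 'ai-jsx/batteries/constrained-output';
import z from 'zod';

// Per the doc comment above, the top-level schema must be a Zod object,
// so the string array is wrapped in an object.
const schema = z.object({ colors: z.array(z.string()) });

const app = (
  <JsonChatCompletion schema={schema}>
    <UserMessage>List three colors of the rainbow.</UserMessage>
  </JsonChatCompletion>
);

AI.createRenderContext()
  .render(app)
  .then((json) => console.log(JSON.parse(json)));
```
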
@@ -148,6 +152,16 @@ export async function* YamlChatCompletion(
);
}

export class CompletionError extends AIJSXError {
constructor(
message: string,
public readonly blame: ErrorBlame,
public readonly metadata: Jsonifiable & { output: string; validationError: string }
) {
super(message, ErrorCode.ModelOutputDidNotMatchConstraint, blame, metadata);
}
}
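
The new `CompletionError` exists so callers can inspect the failed output instead of only a message string. A rough consumer-side sketch follows; it reuses the `AI` import and `app` element from the earlier sketch, and the catch body is illustrative rather than code from this commit.

```tsx
import { CompletionError } from 'ai-jsx/batteries/constrained-output';

try {
  const json = await AI.createRenderContext().render(app);
  console.log(JSON.parse(json));
} catch (e) {
  if (e instanceof CompletionError) {
    // Structured metadata added in this commit: the raw model output plus the
    // parser/validator message, which ObjectCompletionWithRetry also reads below.
    console.error('Model output:', e.metadata.output);
    console.error('Validation error:', e.metadata.validationError);
  } else {
    throw e;
  }
}
```
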

/**
* A {@link ChatCompletion} component that constrains the output to be a valid object format (e.g. JSON/YAML).
*
@@ -159,7 +173,7 @@ export async function* YamlChatCompletion(
*/
async function* OneShotObjectCompletion(
{ children, typeName, validators, example, schema, parser, partialResultCleaner, ...props }: TypedObjectCompletion,
{ render, logger }: AI.ComponentContext
{ render }: AI.ComponentContext
) {
// If a schema is provided, it is added to the list of validators as well as the prompt.
const validatorsAndSchema = schema ? [schema.parse, ...(validators ?? [])] : validators ?? [];
@@ -170,9 +184,9 @@ async function* OneShotObjectCompletion(
<SystemMessage>
Respond with a {typeName} object that encodes your response.
{schema
? `The ${typeName} object should match this JSON Schema: ${JSON.stringify(zodToJsonSchema(schema))}`
? `The ${typeName} object should match this JSON Schema: ${JSON.stringify(zodToJsonSchema(schema))}\n`
: ''}
{example ? `For example: \n${example}` : ''}
{example ? `For example: ${example}\n` : ''}
Respond with only the {typeName} object. Do not include any explanatory prose. Do not include ```
{typeName.toLowerCase()} ``` code blocks.
</SystemMessage>
@@ -191,11 +205,11 @@ async function* OneShotObjectCompletion(
}
} catch (e: any) {
if (partial.done) {
logger.warn(
{ output: partial.value, cleaned: partialResultCleaner ? str : undefined, errorMessage: e.message },
"ObjectCompletion failed. The final result either didn't parse or didn't validate."
);
throw e;
throw new CompletionError(`The model did not produce a valid ${typeName} object`, 'runtime', {
typeName,
output: partial.value,
validationError: e.message,
});
}
continue;
}
@@ -230,7 +244,8 @@ async function* ObjectCompletionWithRetry(
output = yield* render(childrenWithCompletion);
return output;
} catch (e: any) {
validationError = e.message;
validationError = e.metadata.validationError;
output = e.metadata.output;
}

logger.debug({ attempt: 1, expectedFormat: props.typeName, output }, `Output did not validate to ${props.typeName}.`);
@@ -244,9 +259,10 @@
<AssistantMessage>{output}</AssistantMessage>
<UserMessage>
Try again. Here's the validation error when trying to parse the output as {props.typeName}:{'\n'}
```log filename="error.log"{'\n'}
{validationError}
{'\n'}
You must reformat the string to be a valid {props.typeName} object, but you must keep the same data.
{'\n```\n'}
You must reformat your previous output to be a valid {props.typeName} object, but you must keep the same data.
</UserMessage>
</OneShotObjectCompletion>
);
@@ -255,7 +271,8 @@
output = yield* render(completionRetry);
return output;
} catch (e: any) {
validationError = e.message;
validationError = e.metadata.validationError;
output = e.metadata.output;
}

logger.debug(
@@ -264,14 +281,14 @@
);
}

throw new AIJSXError(
throw new CompletionError(
`The model did not produce a valid ${props.typeName} object, even after ${retries} attempts.`,
ErrorCode.ModelOutputDidNotMatchConstraint,
'runtime',
{
typeName: props.typeName,
retries,
output,
validationError,
}
);
}
10 changes: 7 additions & 3 deletions packages/ai-jsx/src/core/completion.tsx
@@ -48,10 +48,8 @@ export interface FunctionDefinition {
/**
* This function creates a [JSON Schema](https://json-schema.org/) object to describe
* parameters for a {@link FunctionDefinition}.
* The parameters can be described either using a record of parameter names to
* {@link PlainFunctionParameter} objects, or using a {@link z.ZodObject} schema object.
*
* @note If using a Zod schema, the top-level schema must be an object as per OpenAI specifications.
* See {@link FunctionParameters} for more information on what parameters are supported.
*/
export function getParametersSchema(parameters: FunctionParameters) {
if (parameters instanceof z.ZodObject) {
@@ -90,6 +88,12 @@ export interface PlainFunctionParameter {
*
* @note If using a Zod schema, the top-level schema must be an object as per OpenAI specifications:
* https://platform.openai.com/docs/api-reference/chat/create#chat/create-parameters
*
* For example, to describe a list of strings, the following is not accepted:
* `const schema: z.Schema = z.array(z.string())`
*
* Instead, you can wrap it in an object like so:
* `const schema: z.ZodObject = z.object({ arr: z.array(z.string()) })`
*/
export type FunctionParameters = Record<string, PlainFunctionParameter> | z.ZodObject<any>;
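
To make the updated `FunctionParameters` doc concrete, here is a sketch of the two accepted shapes. The field names in the plain-record form (`description`, `type`, `required`) are an assumption about `PlainFunctionParameter` and should be checked against its actual definition; the Zod wrapping rule is from the doc comment above, and the import path is assumed from the package layout.

```tsx
import z from 'zod';
import type { FunctionParameters } from 'ai-jsx/core/completion';

// Form 1: a plain record of parameter names to descriptions (field names assumed).
const plainParams: FunctionParameters = {
  city: { description: 'City to look up weather for', type: 'string', required: true },
};

// Form 2: a Zod schema. The top level must be z.object(); arrays are wrapped.
const zodParams: FunctionParameters = z.object({
  cities: z.array(z.string()),
});
```
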

1 change: 0 additions & 1 deletion packages/ai-jsx/src/core/errors.ts
@@ -26,7 +26,6 @@ export enum ErrorCode {
AnthropicAPIError = 1021,
ChatModelDoesNotSupportFunctions = 1022,
ChatCompletionBadInput = 1023,
InvalidParamSchemaType = 1024,

ModelOutputDidNotMatchConstraint = 2000,

5 changes: 2 additions & 3 deletions packages/ai-jsx/src/lib/openai.tsx
@@ -35,9 +35,8 @@ import GPT3Tokenizer from 'gpt3-tokenizer';
import { Merge, MergeExclusive } from 'type-fest';
import { Logger } from '../core/log.js';
import { HttpError, AIJSXError, ErrorCode } from '../core/errors.js';
import { getEnvVar } from './util.js';
import { getEnvVar, patchedUntruncateJson } from './util.js';
import { ChatOrCompletionModelOrBoth } from './model.js';
import untruncateJson from 'untruncate-json';

// https://platform.openai.com/docs/models/model-endpoint-compatibility
type ValidCompletionModel =
@@ -437,7 +436,7 @@ export async function* OpenAIChatModel(
if (props.experimental_streamFunctionCallOnly) {
yield JSON.stringify({
...currentMessage.function_call,
arguments: JSON.parse(untruncateJson.default(currentMessage.function_call.arguments || '{}')),
arguments: JSON.parse(patchedUntruncateJson(currentMessage.function_call.arguments || '{}')),
});
}
}
7 changes: 7 additions & 0 deletions packages/ai-jsx/src/lib/util.ts
@@ -1,3 +1,4 @@
import untruncateJson from 'untruncate-json';
import { AIJSXError } from '../core/errors.js';

/** @hidden */
@@ -14,3 +15,9 @@ export function getEnvVar(name: string, shouldThrow: boolean = true) {
}
return result;
}

/**
* There's an ESM issue with untruncate-json, so we need to do this to support running on both client & server.
*/
/** @hidden */
export const patchedUntruncateJson = 'default' in untruncateJson ? untruncateJson.default : untruncateJson;
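
The ternary above handles the CJS/ESM default-export mismatch: depending on how the module resolves, the callable function is either `untruncateJson` itself or its `.default` property. A small usage sketch follows; the input string and the exact completed output are illustrative, and the deep import path is an assumption (inside the package it is imported as '../lib/util.js').

```tsx
import { patchedUntruncateJson } from 'ai-jsx/lib/util';

// A JSON payload cut off mid-stream, e.g. a partial function_call.arguments chunk.
const partial = '{"name": "Alice", "tags": ["adm';

// untruncate-json closes the dangling string, array, and object so the result parses.
const completed = patchedUntruncateJson(partial); // roughly '{"name": "Alice", "tags": ["adm"]}'
console.log(JSON.parse(completed));
```
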
