From f44a6da5c6fb415918a55afd67daabe067f6a2eb Mon Sep 17 00:00:00 2001 From: Sjors Gielen Date: Fri, 10 Dec 2021 11:40:49 +0100 Subject: [PATCH] Only change codegen if useOptionals is messages/all. --- integration/use-optionals-all/test.bin | Bin 2755 -> 2755 bytes integration/use-optionals-all/test.ts | 38 ++++++----- src/main.ts | 90 +++++++++++++------------ src/types.ts | 39 ++++++++--- 4 files changed, 101 insertions(+), 66 deletions(-) diff --git a/integration/use-optionals-all/test.bin b/integration/use-optionals-all/test.bin index 7a0ae20f64d7f534ae8ab06ae6f9c8227c265050..11c02597f2d0e69c9ad97bc5c7c0eff732c3b4ee 100644 GIT binary patch delta 29 kcmX>sdRSD1i>oBHxJ0j@D8D3Mii3k$Kv;rtqsV$L0D~I`;s5{u delta 29 kcmX>sdRSD1i>oBHxJ0j@D8D3Mii3k$fJcI5qsV$L0D}1k-T(jq diff --git a/integration/use-optionals-all/test.ts b/integration/use-optionals-all/test.ts index e2f8657d1..29820e5ab 100644 --- a/integration/use-optionals-all/test.ts +++ b/integration/use-optionals-all/test.ts @@ -44,7 +44,7 @@ export function stateEnumToJSON(object: StateEnum): string { export interface OptionalsTest { id?: number; - child?: Child | undefined; + child?: Child; state?: StateEnum; long?: number; truth?: boolean; @@ -89,7 +89,7 @@ const baseOptionalsTest: object = { export const OptionalsTest = { encode(message: OptionalsTest, writer: Writer = Writer.create()): Writer { - if (!!message.id) { + if (message.id !== undefined && message.id !== 0) { writer.uint32(8).int32(message.id); } if (message.child !== undefined) { @@ -98,7 +98,7 @@ export const OptionalsTest = { if (message.state !== undefined && message.state !== 0) { writer.uint32(24).int32(message.state); } - if (!!message.long) { + if (message.long !== undefined && message.long !== 0) { writer.uint32(32).int64(message.long); } if (message.truth === true) { @@ -402,7 +402,7 @@ export const OptionalsTest = { return obj; }, - fromPartial(object: DeepPartial): OptionalsTest { + fromPartial, I>>(object: I): OptionalsTest { const message = { ...baseOptionalsTest } as OptionalsTest; message.id = object.id ?? 0; message.child = object.child !== undefined && object.child !== null ? Child.fromPartial(object.child) : undefined; @@ -411,13 +411,13 @@ export const OptionalsTest = { message.truth = object.truth ?? false; message.description = object.description ?? ''; message.data = object.data ?? new Uint8Array(); - message.repId = (object.repId ?? []).map((e) => e); - message.repChild = (object.repChild ?? []).map((e) => Child.fromPartial(e)); - message.repState = (object.repState ?? []).map((e) => e); - message.repLong = (object.repLong ?? []).map((e) => e); - message.repTruth = (object.repTruth ?? []).map((e) => e); - message.repDescription = (object.repDescription ?? []).map((e) => e); - message.repData = (object.repData ?? []).map((e) => e); + message.repId = object.repId?.map((e) => e) || []; + message.repChild = object.repChild?.map((e) => Child.fromPartial(e)) || []; + message.repState = object.repState?.map((e) => e) || []; + message.repLong = object.repLong?.map((e) => e) || []; + message.repTruth = object.repTruth?.map((e) => e) || []; + message.repDescription = object.repDescription?.map((e) => e) || []; + message.repData = object.repData?.map((e) => e) || []; message.optId = object.optId ?? undefined; message.optChild = object.optChild !== undefined && object.optChild !== null ? Child.fromPartial(object.optChild) : undefined; @@ -443,10 +443,10 @@ const baseOptionalsTest_TranslationsEntry: object = { key: '', value: '' }; export const OptionalsTest_TranslationsEntry = { encode(message: OptionalsTest_TranslationsEntry, writer: Writer = Writer.create()): Writer { - if (message.key !== undefined && message.key !== '') { + if (message.key !== '') { writer.uint32(10).string(message.key); } - if (message.value !== undefined && message.value !== '') { + if (message.value !== '') { writer.uint32(18).string(message.value); } return writer; @@ -487,7 +487,9 @@ export const OptionalsTest_TranslationsEntry = { return obj; }, - fromPartial(object: DeepPartial): OptionalsTest_TranslationsEntry { + fromPartial, I>>( + object: I + ): OptionalsTest_TranslationsEntry { const message = { ...baseOptionalsTest_TranslationsEntry } as OptionalsTest_TranslationsEntry; message.key = object.key ?? ''; message.value = object.value ?? ''; @@ -527,7 +529,7 @@ export const Child = { return obj; }, - fromPartial(_: DeepPartial): Child { + fromPartial, I>>(_: I): Child { const message = { ...baseChild } as Child; return message; }, @@ -566,6 +568,7 @@ function base64FromBytes(arr: Uint8Array): string { } type Builtin = Date | Function | Uint8Array | string | number | boolean | undefined; + export type DeepPartial = T extends Builtin ? T : T extends Array @@ -576,6 +579,11 @@ export type DeepPartial = T extends Builtin ? { [K in keyof T]?: DeepPartial } : Partial; +type KeysOfUnion = T extends T ? keyof T : never; +export type Exact = P extends Builtin + ? P + : P & { [K in keyof P]: Exact } & Record>, never>; + function longToNumber(long: Long): number { if (long.gt(Number.MAX_SAFE_INTEGER)) { throw new globalThis.Error('Value is larger than Number.MAX_SAFE_INTEGER'); diff --git a/src/main.ts b/src/main.ts index acf4095f5..8bca547cd 100644 --- a/src/main.ts +++ b/src/main.ts @@ -24,6 +24,7 @@ import { isLongValueType, isMapType, isMessage, + isOptionalProperty, isPrimitive, isRepeated, isScalar, @@ -555,24 +556,6 @@ function makeTimestampMethods(options: Options, longs: ReturnType { + Object.entries(message.${fieldName}${optionalAlternative}).forEach(([key, value]) => { ${entryWriteSnippet} }); `); } else if (packedType(field.type) === undefined) { - chunks.push(code` - if (message.${fieldName} !== undefined && message.${fieldName}.length !== 0) { - for (const v of message.${fieldName}) { - ${writeSnippet('v!')}; - } + const listWriteSnippet = code` + for (const v of message.${fieldName}) { + ${writeSnippet('v!')}; } - `); + `; + if (isOptional) { + chunks.push(code` + if (message.${fieldName} !== undefined && message.${fieldName}.length !== 0) { + ${listWriteSnippet} + } + `); + } else { + chunks.push(listWriteSnippet) + } } else if (isEnum(field) && options.stringEnums) { // This is a lot like the `else` clause, but we wrap `fooToNumber` around it. // Ideally we'd reuse `writeSnippet` here, but `writeSnippet` has the `writer.uint32(tag)` @@ -906,27 +898,41 @@ function generateEncode(ctx: Context, fullName: string, messageDesc: DescriptorP // (i.e. just one tag and multiple values). const tag = ((field.number << 3) | 2) >>> 0; const toNumber = getEnumMethod(ctx, field.typeName, 'ToNumber'); - chunks.push(code` - if (message.${fieldName} !== undefined && message.${fieldName}.length !== 0) { - writer.uint32(${tag}).fork(); - for (const v of message.${fieldName}) { - writer.${toReaderCall(field)}(${toNumber}(v)); - } - writer.ldelim(); + const listWriteSnippet = code` + writer.uint32(${tag}).fork(); + for (const v of message.${fieldName}) { + writer.${toReaderCall(field)}(${toNumber}(v)); } - `); + writer.ldelim(); + `; + if (isOptional) { + chunks.push(code` + if (message.${fieldName} !== undefined && message.${fieldName}.length !== 0) { + ${listWriteSnippet} + } + `); + } else { + chunks.push(listWriteSnippet) + } } else { // Ideally we'd reuse `writeSnippet` but it has tagging embedded inside of it. const tag = ((field.number << 3) | 2) >>> 0; - chunks.push(code` - if (message.${fieldName} !== undefined && message.${fieldName}.length !== 0) { - writer.uint32(${tag}).fork(); - for (const v of message.${fieldName}) { - writer.${toReaderCall(field)}(v); - } - writer.ldelim(); + const listWriteSnippet = code` + writer.uint32(${tag}).fork(); + for (const v of message.${fieldName}) { + writer.${toReaderCall(field)}(v); } - `); + writer.ldelim(); + `; + if (isOptional) { + chunks.push(code` + if (message.${fieldName} !== undefined && message.${fieldName}.length !== 0) { + ${listWriteSnippet} + } + `); + } else { + chunks.push(listWriteSnippet) + } } } else if (isWithinOneOfThatShouldBeUnion(options, field)) { let oneofName = maybeSnakeToCamel(messageDesc.oneofDecl[field.oneofIndex].name, options); @@ -950,7 +956,7 @@ function generateEncode(ctx: Context, fullName: string, messageDesc: DescriptorP `); } else if (isScalar(field) || isEnum(field)) { chunks.push(code` - if (${notDefaultCheck(ctx, field, `message.${fieldName}`)}) { + if (${notDefaultCheck(ctx, field, messageDesc.options, `message.${fieldName}`)}) { ${writeSnippet(`message.${fieldName}`)}; } `); diff --git a/src/types.ts b/src/types.ts index f263e1e1a..dc27ebe01 100644 --- a/src/types.ts +++ b/src/types.ts @@ -6,6 +6,7 @@ import { FieldDescriptorProto_Label, FieldDescriptorProto_Type, FileDescriptorProto, + MessageOptions, MethodDescriptorProto, ServiceDescriptorProto, } from 'ts-proto-descriptors'; @@ -231,8 +232,10 @@ export function defaultValue(ctx: Context, field: FieldDescriptorProto): any { } /** Creates code that checks that the field is not the default value. Supports scalars and enums. */ -export function notDefaultCheck(ctx: Context, field: FieldDescriptorProto, place: string): Code { +export function notDefaultCheck(ctx: Context, field: FieldDescriptorProto, messageOptions: MessageOptions | undefined, place: string): Code { const { typeMap, options } = ctx; + const isOptional = isOptionalProperty(field, messageOptions, options); + const maybeNotUndefinedAnd = isOptional ? `${place} !== undefined && ` : ""; switch (field.type) { case FieldDescriptorProto_Type.TYPE_DOUBLE: case FieldDescriptorProto_Type.TYPE_FLOAT: @@ -241,7 +244,7 @@ export function notDefaultCheck(ctx: Context, field: FieldDescriptorProto, place case FieldDescriptorProto_Type.TYPE_SINT32: case FieldDescriptorProto_Type.TYPE_FIXED32: case FieldDescriptorProto_Type.TYPE_SFIXED32: - return code`!!${place}`; + return code`${maybeNotUndefinedAnd} ${place} !== 0`; case FieldDescriptorProto_Type.TYPE_ENUM: // proto3 enforces enums starting at 0, however proto2 does not, so we have // to probe and see if zero is an allowed value. If it's not, pick the first one. @@ -251,9 +254,9 @@ export function notDefaultCheck(ctx: Context, field: FieldDescriptorProto, place const zerothValue = enumProto.value.find((v) => v.number === 0) || enumProto.value[0]; if (options.stringEnums) { const enumType = messageToTypeName(ctx, field.typeName); - return code`${place} !== undefined && ${place} !== ${enumType}.${zerothValue.name}`; + return code`${maybeNotUndefinedAnd} ${place} !== ${enumType}.${zerothValue.name}`; } else { - return code`${place} !== undefined && ${place} !== ${zerothValue.number}`; + return code`${maybeNotUndefinedAnd} ${place} !== ${zerothValue.number}`; } case FieldDescriptorProto_Type.TYPE_UINT64: case FieldDescriptorProto_Type.TYPE_FIXED64: @@ -261,18 +264,18 @@ export function notDefaultCheck(ctx: Context, field: FieldDescriptorProto, place case FieldDescriptorProto_Type.TYPE_SINT64: case FieldDescriptorProto_Type.TYPE_SFIXED64: if (options.forceLong === LongOption.LONG) { - return code`${place} !== undefined && !${place}.isZero()`; + return code`${maybeNotUndefinedAnd} !${place}.isZero()`; } else if (options.forceLong === LongOption.STRING) { - return code`${place} !== undefined && ${place} !== "0"`; + return code`${maybeNotUndefinedAnd} ${place} !== "0"`; } else { - return code`!!${place}`; + return code`${maybeNotUndefinedAnd} ${place} !== 0`; } case FieldDescriptorProto_Type.TYPE_BOOL: return code`${place} === true`; case FieldDescriptorProto_Type.TYPE_STRING: - return code`${place} !== undefined && ${place} !== ""`; + return code`${maybeNotUndefinedAnd} ${place} !== ""`; case FieldDescriptorProto_Type.TYPE_BYTES: - return code`${place} !== undefined && ${place}.length !== 0`; + return code`${maybeNotUndefinedAnd} ${place}.length !== 0`; default: throw new Error('Not implemented for the given type.'); } @@ -325,6 +328,24 @@ export function isScalar(field: FieldDescriptorProto): boolean { return scalarTypes.includes(field.type); } +// When useOptionals='messages', non-scalar fields are translated into optional +// properties. When useOptionals='all', all fields are translated into +// optional properties, with the exception of map Entry key/values, which must +// always be present. +export function isOptionalProperty( + field: FieldDescriptorProto, + messageOptions: MessageOptions | undefined, + options: Options +): boolean { + const optionalMessages = options.useOptionals === true || options.useOptionals === 'messages' || options.useOptionals === 'all'; + const optionalAll = options.useOptionals === 'all'; + return ( + (optionalMessages && isMessage(field) && !isRepeated(field)) || + (optionalAll && !messageOptions?.mapEntry) || + field.proto3Optional + ); +} + /** This includes all scalars, enums and the [groups type](https://developers.google.com/protocol-buffers/docs/reference/java/com/google/protobuf/DescriptorProtos.FieldDescriptorProto.Type.html#TYPE_GROUP) */ export function isPrimitive(field: FieldDescriptorProto): boolean { return !isMessage(field);