From 450c9ad4d1ee0dfe5a0c06b02128b3c67ce87646 Mon Sep 17 00:00:00 2001 From: Neville Li Date: Tue, 8 Nov 2022 09:30:37 -0500 Subject: [PATCH] Add Protobuf decoder --- lib/trino-record-decoder/pom.xml | 31 ++ .../java/io/trino/decoder/DecoderModule.java | 2 + .../protobuf/DynamicMessageProvider.java | 28 ++ .../FixedSchemaDynamicMessageProvider.java | 67 ++++ .../protobuf/ProtobufColumnDecoder.java | 147 ++++++++ .../protobuf/ProtobufDecoderModule.java | 32 ++ .../decoder/protobuf/ProtobufErrorCode.java | 43 +++ .../decoder/protobuf/ProtobufRowDecoder.java | 55 +++ .../protobuf/ProtobufRowDecoderFactory.java | 49 +++ .../trino/decoder/protobuf/ProtobufUtils.java | 272 +++++++++++++++ .../protobuf/ProtobufValueProvider.java | 322 ++++++++++++++++++ .../protobuf/ProtobufDataProviders.java | 71 ++++ .../decoder/protobuf/TestProtobufDecoder.java | 280 +++++++++++++++ .../decoder/protobuf/all_datatypes.proto | 19 ++ .../protobuf/structural_datatypes.proto | 31 ++ plugin/trino-kafka/pom.xml | 15 + plugin/trino-kinesis/pom.xml | 10 + plugin/trino-redis/pom.xml | 15 + pom.xml | 34 ++ 19 files changed, 1523 insertions(+) create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/DynamicMessageProvider.java create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/FixedSchemaDynamicMessageProvider.java create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufColumnDecoder.java create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufDecoderModule.java create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufErrorCode.java create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufRowDecoder.java create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufRowDecoderFactory.java create mode 100644 
lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufUtils.java create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufValueProvider.java create mode 100644 lib/trino-record-decoder/src/test/java/io/trino/decoder/protobuf/ProtobufDataProviders.java create mode 100644 lib/trino-record-decoder/src/test/java/io/trino/decoder/protobuf/TestProtobufDecoder.java create mode 100644 lib/trino-record-decoder/src/test/resources/decoder/protobuf/all_datatypes.proto create mode 100644 lib/trino-record-decoder/src/test/resources/decoder/protobuf/structural_datatypes.proto diff --git a/lib/trino-record-decoder/pom.xml b/lib/trino-record-decoder/pom.xml index cf29152408a8..996425a5eeb4 100644 --- a/lib/trino-record-decoder/pom.xml +++ b/lib/trino-record-decoder/pom.xml @@ -33,6 +33,12 @@ jackson-databind + + com.google.code.findbugs + jsr305 + true + + com.google.guava guava @@ -43,6 +49,16 @@ guice + + com.google.protobuf + protobuf-java + + + + com.squareup.wire + wire-schema + + javax.inject javax.inject @@ -107,4 +123,19 @@ test + + + + + org.basepom.maven + duplicate-finder-maven-plugin + + + + google/protobuf/.*\.proto$ + + + + + diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/DecoderModule.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/DecoderModule.java index 75c4c519c72a..d5702fea6031 100644 --- a/lib/trino-record-decoder/src/main/java/io/trino/decoder/DecoderModule.java +++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/DecoderModule.java @@ -23,6 +23,7 @@ import io.trino.decoder.dummy.DummyRowDecoderFactory; import io.trino.decoder.json.JsonRowDecoder; import io.trino.decoder.json.JsonRowDecoderFactory; +import io.trino.decoder.protobuf.ProtobufDecoderModule; import io.trino.decoder.raw.RawRowDecoder; import io.trino.decoder.raw.RawRowDecoderFactory; @@ -43,6 +44,7 @@ public void configure(Binder binder) 
decoderFactoriesByName.addBinding(JsonRowDecoder.NAME).to(JsonRowDecoderFactory.class).in(SINGLETON); decoderFactoriesByName.addBinding(RawRowDecoder.NAME).to(RawRowDecoderFactory.class).in(SINGLETON); binder.install(new AvroDecoderModule()); + binder.install(new ProtobufDecoderModule()); binder.bind(DispatchingRowDecoderFactory.class).in(SINGLETON); } } diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/DynamicMessageProvider.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/DynamicMessageProvider.java new file mode 100644 index 000000000000..bac8ad6761ee --- /dev/null +++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/DynamicMessageProvider.java @@ -0,0 +1,28 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package io.trino.decoder.protobuf;
+
+import com.google.protobuf.DynamicMessage;
+
+import java.util.Optional;
+
+/**
+ * Turns the serialized bytes of one record into a Protobuf {@link DynamicMessage}.
+ */
+public interface DynamicMessageProvider
+{
+    /**
+     * Parses a single serialized record.
+     *
+     * @param data the serialized record bytes
+     * @return the parsed message
+     */
+    DynamicMessage parseDynamicMessage(byte[] data);
+
+    /**
+     * Creates {@link DynamicMessageProvider} instances for a given schema.
+     */
+    interface Factory
+    {
+        /**
+         * @param protoFile the schema definition, if one was supplied; the fixed-schema
+         *         implementation parses it as {@code .proto} source text
+         * @return a provider that parses records for that schema
+         */
+        DynamicMessageProvider create(Optional<String> protoFile);
+    }
+}
diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/FixedSchemaDynamicMessageProvider.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/FixedSchemaDynamicMessageProvider.java
new file mode 100644
index 000000000000..61b6c92305a6
--- /dev/null
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/FixedSchemaDynamicMessageProvider.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.decoder.protobuf;
+
+import com.google.protobuf.Descriptors.Descriptor;
+import com.google.protobuf.Descriptors.DescriptorValidationException;
+import com.google.protobuf.DynamicMessage;
+import com.google.protobuf.InvalidProtocolBufferException;
+import io.trino.spi.TrinoException;
+
+import java.util.Optional;
+
+import static com.google.common.base.Preconditions.checkState;
+import static io.trino.decoder.protobuf.ProtobufRowDecoderFactory.DEFAULT_MESSAGE;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * {@link DynamicMessageProvider} that parses every record against a single message
+ * {@link Descriptor} supplied at construction time.
+ */
+public class FixedSchemaDynamicMessageProvider
+        implements DynamicMessageProvider
+{
+    private final Descriptor descriptor;
+
+    public FixedSchemaDynamicMessageProvider(Descriptor descriptor)
+    {
+        this.descriptor = requireNonNull(descriptor, "descriptor is null");
+    }
+
+    /**
+     * Parses {@code data} as an instance of the configured message type.
+     *
+     * @throws TrinoException with {@link ProtobufErrorCode#INVALID_PROTOBUF_MESSAGE} when the bytes cannot be parsed
+     */
+    @Override
+    public DynamicMessage parseDynamicMessage(byte[] data)
+    {
+        try {
+            return DynamicMessage.parseFrom(descriptor, data);
+        }
+        catch (InvalidProtocolBufferException e) {
+            throw new TrinoException(ProtobufErrorCode.INVALID_PROTOBUF_MESSAGE, "Decoding Protobuf record failed.", e);
+        }
+    }
+
+    /**
+     * Builds providers from {@code .proto} source text, always using the message named
+     * {@link ProtobufRowDecoderFactory#DEFAULT_MESSAGE}.
+     */
+    public static class Factory
+            implements DynamicMessageProvider.Factory
+    {
+        /**
+         * @param protoFile the {@code .proto} source text; must be present
+         * @throws IllegalStateException when {@code protoFile} is empty or the schema does not define the default message
+         * @throws TrinoException with {@link ProtobufErrorCode#INVALID_PROTO_FILE} when descriptor validation fails
+         */
+        @Override
+        public DynamicMessageProvider create(Optional<String> protoFile)
+        {
+            checkState(protoFile.isPresent(), "proto file is missing");
+            try {
+                Descriptor descriptor = ProtobufUtils.getFileDescriptor(protoFile.orElseThrow()).findMessageTypeByName(DEFAULT_MESSAGE);
+                // findMessageTypeByName signals "no such message" with null
+                checkState(descriptor != null, format("Message %s not found", DEFAULT_MESSAGE));
+                return new FixedSchemaDynamicMessageProvider(descriptor);
+            }
+            catch (DescriptorValidationException descriptorValidationException) {
+                throw new TrinoException(ProtobufErrorCode.INVALID_PROTO_FILE, "Unable to parse protobuf schema", descriptorValidationException);
+            }
+        }
+    }
+}
diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufColumnDecoder.java
b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufColumnDecoder.java
new file mode 100644
index 000000000000..56118db9f00e
--- /dev/null
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufColumnDecoder.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.decoder.protobuf;
+
+import com.google.common.base.Splitter;
+import com.google.common.collect.ImmutableSet;
+import com.google.protobuf.Descriptors.Descriptor;
+import com.google.protobuf.Descriptors.FieldDescriptor;
+import com.google.protobuf.DynamicMessage;
+import io.trino.decoder.DecoderColumnHandle;
+import io.trino.decoder.FieldValueProvider;
+import io.trino.spi.TrinoException;
+import io.trino.spi.type.ArrayType;
+import io.trino.spi.type.BigintType;
+import io.trino.spi.type.BooleanType;
+import io.trino.spi.type.DoubleType;
+import io.trino.spi.type.IntegerType;
+import io.trino.spi.type.MapType;
+import io.trino.spi.type.RealType;
+import io.trino.spi.type.RowType;
+import io.trino.spi.type.SmallintType;
+import io.trino.spi.type.TimestampType;
+import io.trino.spi.type.TinyintType;
+import io.trino.spi.type.Type;
+import io.trino.spi.type.VarbinaryType;
+import io.trino.spi.type.VarcharType;
+
+import javax.annotation.Nullable;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static io.trino.spi.StandardErrorCode.GENERIC_USER_ERROR;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Extracts the value of one column from a decoded {@link DynamicMessage}, following the
+ * column's '/'-separated mapping path through nested message fields.
+ */
+public class ProtobufColumnDecoder
+{
+    private static final Set<Type> SUPPORTED_PRIMITIVE_TYPES = ImmutableSet.of(
+            BooleanType.BOOLEAN,
+            TinyintType.TINYINT,
+            SmallintType.SMALLINT,
+            IntegerType.INTEGER,
+            BigintType.BIGINT,
+            RealType.REAL,
+            DoubleType.DOUBLE,
+            VarbinaryType.VARBINARY);
+
+    private final Type columnType;
+    private final String columnMapping;
+    private final String columnName;
+
+    /**
+     * @param columnHandle the column to decode; must be non-internal, have a mapping, carry no
+     *         format hint or data format, and have a supported type
+     * @throws TrinoException with {@code GENERIC_USER_ERROR} when any of those checks fails
+     */
+    public ProtobufColumnDecoder(DecoderColumnHandle columnHandle)
+    {
+        try {
+            requireNonNull(columnHandle, "columnHandle is null");
+            this.columnType = columnHandle.getType();
+            this.columnMapping = columnHandle.getMapping();
+            this.columnName = columnHandle.getName();
+            checkArgument(!columnHandle.isInternal(), "unexpected internal column '%s'", columnName);
+            checkArgument(columnHandle.getFormatHint() == null, "unexpected format hint '%s' defined for column '%s'", columnHandle.getFormatHint(), columnName);
+            checkArgument(columnHandle.getDataFormat() == null, "unexpected data format '%s' defined for column '%s'", columnHandle.getDataFormat(), columnName);
+            checkArgument(columnHandle.getMapping() != null, "mapping not defined for column '%s'", columnName);
+
+            checkArgument(isSupportedType(columnType), "Unsupported column type '%s' for column '%s'", columnType, columnName);
+        }
+        catch (IllegalArgumentException e) {
+            // Report misconfigured columns as a user error rather than an internal failure
+            throw new TrinoException(GENERIC_USER_ERROR, e);
+        }
+    }
+
+    /**
+     * Returns whether {@code type} is a supported primitive, or an array, map or row built
+     * only from supported types.
+     */
+    private static boolean isSupportedType(Type type)
+    {
+        if (isSupportedPrimitive(type)) {
+            return true;
+        }
+
+        if (type instanceof ArrayType) {
+            checkArgument(type.getTypeParameters().size() == 1, "expecting exactly one type parameter for array");
+            return isSupportedType(type.getTypeParameters().get(0));
+        }
+
+        if (type instanceof MapType) {
+            List<Type> typeParameters = type.getTypeParameters();
+            checkArgument(typeParameters.size() == 2, "expecting exactly two type parameters for map");
+            return isSupportedType(typeParameters.get(0)) && isSupportedType(typeParameters.get(1));
+        }
+
+        if (type instanceof RowType) {
+            for (Type fieldType : type.getTypeParameters()) {
+                if (!isSupportedType(fieldType)) {
+                    return false;
+                }
+            }
+            return true;
+        }
+        return false;
+    }
+
+    /**
+     * Returns whether {@code type} is a short timestamp, any varchar, or one of
+     * {@link #SUPPORTED_PRIMITIVE_TYPES}.
+     */
+    private static boolean isSupportedPrimitive(Type type)
+    {
+        return (type instanceof TimestampType && ((TimestampType) type).isShort()) ||
+                type instanceof VarcharType ||
+                SUPPORTED_PRIMITIVE_TYPES.contains(type);
+    }
+
+    /**
+     * Locates this column's value in {@code dynamicMessage} and wraps it in a value provider.
+     */
+    public FieldValueProvider decodeField(DynamicMessage dynamicMessage)
+    {
+        return new ProtobufValueProvider(locateField(dynamicMessage, columnMapping), columnType, columnName);
+    }
+
+    /**
+     * Walks {@code columnMapping} through {@code message}, one field name per path element.
+     *
+     * @return the value at the end of the path, or {@code null} when a path element names no
+     *         field of the current message type or the path continues past a non-message field
+     */
+    @Nullable
+    private static Object locateField(DynamicMessage message, String columnMapping)
+    {
+        Object value = message;
+        Optional<Descriptor> valueDescriptor = Optional.of(message.getDescriptorForType());
+        for (String pathElement : Splitter.on('/').omitEmptyStrings().split(columnMapping)) {
+            if (valueDescriptor.isEmpty()) {
+                // The previous element was not a message, so there is nothing to descend into
+                return null;
+            }
+            FieldDescriptor fieldDescriptor = valueDescriptor.get().findFieldByName(pathElement);
+            if (fieldDescriptor == null) {
+                return null;
+            }
+            // NOTE(review): a repeated message field in the middle of the path would presumably
+            // yield a List here, which this cast does not handle - confirm mappings never traverse one
+            value = ((DynamicMessage) value).getField(fieldDescriptor);
+            valueDescriptor = getDescriptor(fieldDescriptor);
+        }
+        return value;
+    }
+
+    /**
+     * Returns the message descriptor of {@code fieldDescriptor} when the field is message-typed.
+     */
+    private static Optional<Descriptor> getDescriptor(FieldDescriptor fieldDescriptor)
+    {
+        if (fieldDescriptor.getJavaType() == FieldDescriptor.JavaType.MESSAGE) {
+            return Optional.of(fieldDescriptor.getMessageType());
+        }
+        return Optional.empty();
+    }
+}
diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufDecoderModule.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufDecoderModule.java
new file mode 100644
index 000000000000..e924905128aa
--- /dev/null
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufDecoderModule.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed under the
Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.decoder.protobuf;
+
+import com.google.inject.Binder;
+import com.google.inject.Module;
+import io.trino.decoder.RowDecoderFactory;
+
+import static com.google.inject.Scopes.SINGLETON;
+import static com.google.inject.multibindings.MapBinder.newMapBinder;
+
+/**
+ * Guice module wiring the Protobuf decoder: binds the fixed-schema message provider factory
+ * and registers the row decoder factory under {@link ProtobufRowDecoder#NAME}.
+ */
+public class ProtobufDecoderModule
+        implements Module
+{
+    @Override
+    public void configure(Binder binder)
+    {
+        // Message parsing strategy used by the row decoder factory
+        binder.bind(DynamicMessageProvider.Factory.class)
+                .to(FixedSchemaDynamicMessageProvider.Factory.class)
+                .in(SINGLETON);
+
+        // Entry in the name -> decoder factory map
+        newMapBinder(binder, String.class, RowDecoderFactory.class)
+                .addBinding(ProtobufRowDecoder.NAME)
+                .to(ProtobufRowDecoderFactory.class)
+                .in(SINGLETON);
+    }
+}
diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufErrorCode.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufErrorCode.java
new file mode 100644
index 000000000000..7ca0076099ca
--- /dev/null
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufErrorCode.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.decoder.protobuf;
+
+import io.trino.spi.ErrorCode;
+import io.trino.spi.ErrorCodeSupplier;
+import io.trino.spi.ErrorType;
+
+import static io.trino.spi.ErrorType.EXTERNAL;
+
+/**
+ * Error codes raised by the Protobuf decoder, numbered upwards from {@code 0x0606_0000}.
+ */
+public enum ProtobufErrorCode
+        implements ErrorCodeSupplier
+{
+    INVALID_PROTO_FILE(0, EXTERNAL),
+    MESSAGE_NOT_FOUND(1, EXTERNAL),
+    INVALID_PROTOBUF_MESSAGE(2, EXTERNAL),
+    INVALID_TIMESTAMP(3, EXTERNAL),
+    /**/;
+
+    private final ErrorCode errorCode;
+
+    /**
+     * @param offset position of this code relative to {@code 0x0606_0000}
+     * @param errorType category reported with the error
+     */
+    ProtobufErrorCode(int offset, ErrorType errorType)
+    {
+        this.errorCode = new ErrorCode(0x0606_0000 + offset, name(), errorType);
+    }
+
+    @Override
+    public ErrorCode toErrorCode()
+    {
+        return errorCode;
+    }
+}
diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufRowDecoder.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufRowDecoder.java
new file mode 100644
index 000000000000..441e040ae1f7
--- /dev/null
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufRowDecoder.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.decoder.protobuf;
+
+import com.google.protobuf.DynamicMessage;
+import io.trino.decoder.DecoderColumnHandle;
+import io.trino.decoder.FieldValueProvider;
+import io.trino.decoder.RowDecoder;
+
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
+import static java.util.Objects.requireNonNull;
+import static java.util.function.UnaryOperator.identity;
+
+/**
+ * {@link RowDecoder} that parses each record as a Protobuf message and exposes one
+ * {@link FieldValueProvider} per requested column.
+ */
+public class ProtobufRowDecoder
+        implements RowDecoder
+{
+    public static final String NAME = "protobuf";
+
+    private final DynamicMessageProvider dynamicMessageProvider;
+    private final Map<DecoderColumnHandle, ProtobufColumnDecoder> columnDecoders;
+
+    /**
+     * @param dynamicMessageProvider parses record bytes into messages
+     * @param columns the columns to decode; one {@link ProtobufColumnDecoder} is created per column
+     */
+    public ProtobufRowDecoder(DynamicMessageProvider dynamicMessageProvider, Set<DecoderColumnHandle> columns)
+    {
+        this.dynamicMessageProvider = requireNonNull(dynamicMessageProvider, "dynamicMessageProvider is null");
+        this.columnDecoders = requireNonNull(columns, "columns is null").stream()
+                .collect(toImmutableMap(
+                        identity(),
+                        ProtobufColumnDecoder::new));
+    }
+
+    /**
+     * Parses {@code data} once and decodes every configured column from the resulting message.
+     *
+     * @return a non-empty optional holding the value provider of each column
+     */
+    @Override
+    public Optional<Map<DecoderColumnHandle, FieldValueProvider>> decodeRow(byte[] data)
+    {
+        DynamicMessage message = dynamicMessageProvider.parseDynamicMessage(data);
+        return Optional.of(columnDecoders.entrySet().stream()
+                .collect(toImmutableMap(
+                        Map.Entry::getKey,
+                        entry -> entry.getValue().decodeField(message))));
+    }
+}
diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufRowDecoderFactory.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufRowDecoderFactory.java
new file mode 100644
index 000000000000..52778be4f1b1
--- /dev/null
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufRowDecoderFactory.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.decoder.protobuf;
+
+import io.trino.decoder.DecoderColumnHandle;
+import io.trino.decoder.RowDecoder;
+import io.trino.decoder.RowDecoderFactory;
+import io.trino.decoder.protobuf.DynamicMessageProvider.Factory;
+
+import javax.inject.Inject;
+
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Creates {@link ProtobufRowDecoder} instances from decoder parameters.
+ */
+public class ProtobufRowDecoderFactory
+        implements RowDecoderFactory
+{
+    /**
+     * Name of the message type that the fixed-schema provider looks up in the supplied schema.
+     */
+    public static final String DEFAULT_MESSAGE = "schema";
+
+    private final Factory dynamicMessageProviderFactory;
+
+    @Inject
+    public ProtobufRowDecoderFactory(Factory dynamicMessageProviderFactory)
+    {
+        this.dynamicMessageProviderFactory = requireNonNull(dynamicMessageProviderFactory, "dynamicMessageProviderFactory is null");
+    }
+
+    /**
+     * @param decoderParams decoder parameters; the {@code dataSchema} entry, when present, is
+     *         handed to the message provider factory as the schema
+     * @param columns the columns the decoder must produce
+     */
+    @Override
+    public RowDecoder create(Map<String, String> decoderParams, Set<DecoderColumnHandle> columns)
+    {
+        return new ProtobufRowDecoder(
+                dynamicMessageProviderFactory.create(Optional.ofNullable(decoderParams.get("dataSchema"))),
+                columns);
+    }
+}
diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufUtils.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufUtils.java
new file mode 100644
index 000000000000..02082c37c612
--- /dev/null
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufUtils.java
@@ -0,0 +1,272 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.decoder.protobuf;
+
+import com.google.common.io.Resources;
+import com.google.protobuf.DescriptorProtos.DescriptorProto;
+import com.google.protobuf.DescriptorProtos.EnumDescriptorProto;
+import com.google.protobuf.DescriptorProtos.EnumValueDescriptorProto;
+import com.google.protobuf.DescriptorProtos.FieldDescriptorProto;
+import com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Label;
+import com.google.protobuf.DescriptorProtos.FileDescriptorProto;
+import com.google.protobuf.DescriptorProtos.MessageOptions;
+import com.google.protobuf.Descriptors.Descriptor;
+import com.google.protobuf.Descriptors.DescriptorValidationException;
+import com.google.protobuf.Descriptors.FileDescriptor;
+import com.squareup.wire.schema.Field;
+import com.squareup.wire.schema.Location;
+import com.squareup.wire.schema.ProtoType;
+import com.squareup.wire.schema.internal.parser.EnumConstantElement;
+import com.squareup.wire.schema.internal.parser.EnumElement;
+import com.squareup.wire.schema.internal.parser.FieldElement;
+import com.squareup.wire.schema.internal.parser.MessageElement;
+import com.squareup.wire.schema.internal.parser.ProtoFileElement;
+import com.squareup.wire.schema.internal.parser.ProtoParser;
+import com.squareup.wire.schema.internal.parser.TypeElement;
+
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.util.HashSet;
+import java.util.Optional;
+import java.util.Set;
+
+import static com.google.common.base.CaseFormat.LOWER_UNDERSCORE;
+import static com.google.common.base.CaseFormat.UPPER_CAMEL;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_BOOL;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_BYTES;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_DOUBLE;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_ENUM;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_FIXED32;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_FIXED64;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_FLOAT;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_GROUP;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT32;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT64;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_MESSAGE;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_SFIXED32;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_SFIXED64;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_SINT32;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_SINT64;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_UINT32;
+import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_UINT64;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Converts {@code .proto} source text, parsed with Wire, into Protobuf {@link FileDescriptor}s.
+ */
+public final class ProtobufUtils
+{
+    private ProtobufUtils()
+    {}
+
+    /**
+     * Parses {@code protoFile} as {@code .proto} source text and builds its descriptor.
+     *
+     * @param protoFile the {@code .proto} source text (not a path)
+     * @throws DescriptorValidationException when Protobuf rejects the resulting descriptor
+     */
+    public static FileDescriptor getFileDescriptor(String protoFile)
+            throws DescriptorValidationException
+    {
+        ProtoFileElement protoFileElement = ProtoParser.Companion.parse(Location.get(""), protoFile);
+        return getFileDescriptor(Optional.empty(), protoFileElement);
+    }
+
+    /**
+     * Builds a descriptor from an already parsed file, resolving its imports from the classpath.
+     *
+     * @param fileName name recorded in the descriptor, if any
+     * @param protoFileElement the parsed {@code .proto} file
+     * @throws DescriptorValidationException when Protobuf rejects the resulting descriptor
+     * @throws UncheckedIOException when an imported file cannot be read
+     */
+    public static FileDescriptor getFileDescriptor(Optional<String> fileName, ProtoFileElement protoFileElement)
+            throws DescriptorValidationException
+    {
+        FileDescriptor[] dependencies = new FileDescriptor[protoFileElement.getImports().size()];
+        // Type names known to be messages; any other non-scalar type name is treated as an enum
+        Set<String> definedMessages = new HashSet<>();
+        int index = 0;
+        for (String importStatement : protoFileElement.getImports()) {
+            try {
+                FileDescriptor fileDescriptor = getFileDescriptor(getProtoFile(importStatement));
+                fileDescriptor.getMessageTypes().stream()
+                        .map(Descriptor::getFullName)
+                        .forEach(definedMessages::add);
+                dependencies[index] = fileDescriptor;
+                index++;
+            }
+            catch (IOException exception) {
+                throw new UncheckedIOException(exception);
+            }
+        }
+
+        FileDescriptorProto.Builder builder = FileDescriptorProto.newBuilder();
+
+        if (protoFileElement.getSyntax() != null) {
+            builder.setSyntax(protoFileElement.getSyntax().name());
+        }
+
+        fileName.ifPresent(builder::setName);
+        builder.addAllDependency(protoFileElement.getImports());
+
+        if (protoFileElement.getPackageName() != null) {
+            builder.setPackage(protoFileElement.getPackageName());
+        }
+
+        // Register every top-level message before processing any of them, so that a field may
+        // reference its own enclosing message or a message declared later in the file
+        for (TypeElement element : protoFileElement.getTypes()) {
+            if (element instanceof MessageElement) {
+                definedMessages.add(element.getName());
+            }
+        }
+
+        for (TypeElement element : protoFileElement.getTypes()) {
+            if (element instanceof MessageElement) {
+                builder.addMessageType(processMessage((MessageElement) element, definedMessages));
+            }
+            if (element instanceof EnumElement) {
+                builder.addEnumType(processEnum((EnumElement) element));
+            }
+        }
+
+        return FileDescriptor.buildFrom(builder.build(), dependencies);
+    }
+
+    /**
+     * Reads a classpath resource as UTF-8 text.
+     *
+     * @param filePath resource path relative to the classpath root
+     * @throws IOException when the resource cannot be read
+     */
+    public static String getProtoFile(String filePath)
+            throws IOException
+    {
+        return Resources.toString(Resources.getResource(ProtobufUtils.class, "/" + filePath), UTF_8);
+    }
+
+    /**
+     * Converts one parsed message, including its nested types and fields, to a descriptor proto.
+     *
+     * @param message the parsed message
+     * @param globallyDefinedMessages message type names visible from the enclosing scope
+     */
+    private static DescriptorProto processMessage(MessageElement message, Set<String> globallyDefinedMessages)
+    {
+        DescriptorProto.Builder builder = DescriptorProto.newBuilder();
+        builder.setName(message.getName());
+
+        // Names visible inside this message: the enclosing scope plus all nested messages (added
+        // up front so declaration order does not matter), minus nested enums, which shadow an
+        // outer message of the same name
+        Set<String> definedMessages = new HashSet<>(globallyDefinedMessages);
+        for (TypeElement typeElement : message.getNestedTypes()) {
+            if (typeElement instanceof MessageElement) {
+                definedMessages.add(typeElement.getName());
+            }
+            else if (typeElement instanceof EnumElement) {
+                definedMessages.remove(typeElement.getName());
+            }
+        }
+
+        for (TypeElement typeElement : message.getNestedTypes()) {
+            if (typeElement instanceof EnumElement) {
+                builder.addEnumType(processEnum((EnumElement) typeElement));
+            }
+            if (typeElement instanceof MessageElement) {
+                builder.addNestedType(processMessage((MessageElement) typeElement, definedMessages));
+            }
+        }
+
+        for (FieldElement field : message.getFields()) {
+            ProtoType protoType = ProtoType.get(field.getType());
+            FieldDescriptorProto.Builder fieldDescriptor = FieldDescriptorProto.newBuilder()
+                    .setName(field.getName())
+                    .setNumber(field.getTag());
+            if (protoType.isMap()) {
+                requireNonNull(protoType.getKeyType(), "keyType is null");
+                requireNonNull(protoType.getValueType(), "valueType is null");
+                // A map field is described as a repeated nested "<FieldName>Entry" message with
+                // the map-entry option set and "key" (1) / "value" (2) fields
+                builder.addNestedType(DescriptorProto.newBuilder()
+                        .setName(getNameForMapField(field.getName()))
+                        .setOptions(MessageOptions.newBuilder().setMapEntry(true).build())
+                        .addField(
+                                processType(
+                                        FieldDescriptorProto.newBuilder()
+                                                .setName("key")
+                                                .setNumber(1)
+                                                .setLabel(Label.LABEL_OPTIONAL),
+                                        protoType.getKeyType(),
+                                        definedMessages))
+                        .addField(
+                                processType(
+                                        FieldDescriptorProto.newBuilder()
+                                                .setName("value")
+                                                .setNumber(2)
+                                                .setLabel(Label.LABEL_OPTIONAL),
+                                        protoType.getValueType(),
+                                        definedMessages))
+                        .build());
+                fieldDescriptor.setType(TYPE_MESSAGE)
+                        .setLabel(Label.LABEL_REPEATED)
+                        .setTypeName(getNameForMapField(field.getName()));
+            }
+            else {
+                processType(fieldDescriptor, protoType, definedMessages);
+            }
+            // oneof members carry no descriptor label of their own
+            if (field.getLabel() != null && field.getLabel() != Field.Label.ONE_OF) {
+                fieldDescriptor.setLabel(getLabel(field.getLabel()));
+            }
+            if (field.getDefaultValue() != null) {
+                fieldDescriptor.setDefaultValue(field.getDefaultValue());
+            }
+            builder.addField(fieldDescriptor.build());
+        }
+        return builder.build();
+    }
+
+    /**
+     * Converts one parsed enum, with the name and number of each constant, to a descriptor proto.
+     */
+    private static EnumDescriptorProto processEnum(EnumElement enumElement)
+    {
+        EnumDescriptorProto.Builder enumBuilder = EnumDescriptorProto.newBuilder();
+        enumBuilder.setName(enumElement.getName());
+        for (EnumConstantElement enumConstant : enumElement.getConstants()) {
+            enumBuilder.addValue(EnumValueDescriptorProto.newBuilder()
+                    .setName(enumConstant.getName())
+                    .setNumber(enumConstant.getTag())
+                    .build());
+        }
+        return enumBuilder.build();
+    }
+
+    /**
+     * Sets the descriptor type of {@code builder} from a parsed type.
+     *
+     * @param builder the field builder to update
+     * @param type the parsed field type
+     * @param messageNames type names to treat as messages; any other non-scalar name is
+     *         treated as an enum
+     * @return {@code builder}
+     */
+    public static FieldDescriptorProto.Builder processType(FieldDescriptorProto.Builder builder, ProtoType type, Set<String> messageNames)
+    {
+        switch (type.getSimpleName()) {
+            case "double":
+                return builder.setType(TYPE_DOUBLE);
+            case "float":
+                return builder.setType(TYPE_FLOAT);
+            case "int64":
+                return builder.setType(TYPE_INT64);
+            case "uint64":
+                return builder.setType(TYPE_UINT64);
+            case "int32":
+                return builder.setType(TYPE_INT32);
+            case "fixed64":
+                return builder.setType(TYPE_FIXED64);
+            case "fixed32":
+                return builder.setType(TYPE_FIXED32);
+            case "bool":
+                return builder.setType(TYPE_BOOL);
+            case "string":
+                return builder.setType(TYPE_STRING);
+            case "group":
+                return builder.setType(TYPE_GROUP);
+            case "bytes":
+                return builder.setType(TYPE_BYTES);
+            case "uint32":
+                return builder.setType(TYPE_UINT32);
+            case "sfixed32":
+                return builder.setType(TYPE_SFIXED32);
+            case "sfixed64":
+                return builder.setType(TYPE_SFIXED64);
+            case "sint32":
+                return builder.setType(TYPE_SINT32);
+            case "sint64":
+                return builder.setType(TYPE_SINT64);
+            default: {
+                // Not a scalar: a named type, classified by whether it is a known message
+                builder.setTypeName(type.toString());
+                if (messageNames.contains(type.toString())) {
+                    builder.setType(TYPE_MESSAGE);
+                }
+                else {
+                    builder.setType(TYPE_ENUM);
+                }
+                return builder;
+            }
+        }
+    }
+
+    /**
+     * Maps a Wire field label to the corresponding descriptor label.
+     *
+     * @throws IllegalArgumentException for any label other than optional, repeated or required
+     */
+    public static Label getLabel(Field.Label label)
+    {
+        switch (label) {
+            case OPTIONAL:
+                return Label.LABEL_OPTIONAL;
+            case REPEATED:
+                return Label.LABEL_REPEATED;
+            case REQUIRED:
+                return Label.LABEL_REQUIRED;
+            default:
+                throw new IllegalArgumentException("Unknown label");
+        }
+    }
+
+    /**
+     * Derives the entry message name of a map field: the field name converted from
+     * lower_underscore to UpperCamel, followed by {@code Entry}.
+     */
+    private static String getNameForMapField(String fieldName)
+    {
+        return LOWER_UNDERSCORE.to(UPPER_CAMEL, fieldName) + "Entry";
+    }
+}
diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufValueProvider.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufValueProvider.java
new file mode 100644
index 000000000000..a7f80a2d4d88
--- /dev/null
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufValueProvider.java
@@ -0,0 +1,322 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package io.trino.decoder.protobuf; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Descriptors.EnumValueDescriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.DynamicMessage; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.decoder.FieldValueProvider; +import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.IntegerType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RealType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.SmallintType; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.TinyintType; +import io.trino.spi.type.Type; +import io.trino.spi.type.VarbinaryType; +import io.trino.spi.type.VarcharType; + +import javax.annotation.Nullable; + +import java.util.Collection; +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.decoder.DecoderErrorCode.DECODER_CONVERSION_NOT_SUPPORTED; +import static io.trino.spi.type.TimestampType.MAX_SHORT_PRECISION; +import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND; +import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND; +import static io.trino.spi.type.Timestamps.round; +import static io.trino.spi.type.Timestamps.roundDiv; +import static io.trino.spi.type.Varchars.truncateToLength; +import static java.lang.Float.floatToIntBits; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class ProtobufValueProvider + extends 
FieldValueProvider
+{
+    // Decoded protobuf value for one column; null when the column has no value
+    @Nullable
+    private final Object value;
+    private final Type columnType;
+    private final String columnName;
+
+    /**
+     * @param value the decoded protobuf object, or null for a NULL column value
+     * @param columnType the Trino type the value is exposed as
+     * @param columnName the column name, used only in error messages
+     */
+    public ProtobufValueProvider(@Nullable Object value, Type columnType, String columnName)
+    {
+        this.value = value;
+        this.columnType = requireNonNull(columnType, "columnType is null");
+        this.columnName = requireNonNull(columnName, "columnName is null");
+    }
+
+    @Override
+    public boolean isNull()
+    {
+        return value == null;
+    }
+
+    /**
+     * Returns the value as a double; accepts {@link Double} and {@link Float} values.
+     *
+     * @throws TrinoException with DECODER_CONVERSION_NOT_SUPPORTED for any other value class
+     */
+    @Override
+    public double getDouble()
+    {
+        requireNonNull(value, "value is null");
+        if (value instanceof Double || value instanceof Float) {
+            return ((Number) value).doubleValue();
+        }
+        throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("cannot decode object of '%s' as '%s' for column '%s'", value.getClass(), columnType, columnName));
+    }
+
+    /**
+     * Returns the value as a boolean; accepts only {@link Boolean} values.
+     *
+     * @throws TrinoException with DECODER_CONVERSION_NOT_SUPPORTED for any other value class
+     */
+    @Override
+    public boolean getBoolean()
+    {
+        requireNonNull(value, "value is null");
+        if (value instanceof Boolean) {
+            return (Boolean) value;
+        }
+        throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("cannot decode object of '%s' as '%s' for column '%s'", value.getClass(), columnType, columnName));
+    }
+
+    /**
+     * Returns the value as a long. Integers and longs are widened, floats are
+     * returned as their int bit pattern, and a {@link DynamicMessage} is parsed
+     * as a timestamp (the column type must then be a {@link TimestampType}).
+     *
+     * @throws TrinoException with DECODER_CONVERSION_NOT_SUPPORTED for any other value class
+     */
+    @Override
+    public long getLong()
+    {
+        requireNonNull(value, "value is null");
+        if (value instanceof Long || value instanceof Integer) {
+            return ((Number) value).longValue();
+        }
+        if (value instanceof Float) {
+            // Same encoding serializePrimitive uses when writing a REAL
+            return floatToIntBits((Float) value);
+        }
+        if (value instanceof DynamicMessage) {
+            checkArgument(columnType instanceof TimestampType, "type should be an instance of Timestamp");
+            return parseTimestamp(((TimestampType) columnType).getPrecision(), (DynamicMessage) value);
+        }
+        throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("cannot decode object of '%s' as '%s' for column '%s'", value.getClass(), columnType, columnName));
+    }
+
+    @Override
+    public Slice getSlice()
+    {
+        return getSlice(value, columnType, columnName);
+    }
+
+    @Override
+    public Block getBlock()
+    {
+        return serializeObject(null, value, columnType, columnName);
+    }
+
+    /**
+     * Converts a string-like, enum or bytes value into a slice.
+     * Character sequences and enum values are only accepted for VARCHAR targets
+     * (and truncated to the declared length); byte strings only for VARBINARY.
+     *
+     * @throws TrinoException with DECODER_CONVERSION_NOT_SUPPORTED for any other combination
+     */
+    private static Slice getSlice(Object value, Type type, String columnName)
+    {
+        requireNonNull(value, "value is null");
+        // The VARCHAR check guards the enum case too: an enum mapped to another
+        // type must reach the conversion error below, not the truncation call.
+        if (type instanceof VarcharType && (value instanceof CharSequence || value instanceof EnumValueDescriptor)) {
+            return truncateToLength(utf8Slice(value.toString()), type);
+        }
+
+        if (type instanceof VarbinaryType && value instanceof ByteString) {
+            return Slices.wrappedBuffer(((ByteString) value).toByteArray());
+        }
+
+        throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("cannot decode object of '%s' as '%s' for column '%s'", value.getClass(), type, columnName));
+    }
+
+    /**
+     * Dispatches on the Trino type: arrays, maps and rows go to their dedicated
+     * serializers, everything else is written as a primitive into {@code builder}.
+     *
+     * @return the built block when a structural value is serialized without a parent
+     *         builder; null when the value was appended to {@code builder}
+     */
+    @Nullable
+    private static Block serializeObject(BlockBuilder builder, Object value, Type type, String columnName)
+    {
+        if (type instanceof ArrayType) {
+            return serializeList(builder, value, type, columnName);
+        }
+        if (type instanceof MapType) {
+            return serializeMap(builder, value, type, columnName);
+        }
+        if (type instanceof RowType) {
+            return serializeRow(builder, value, type, columnName);
+        }
+
+        serializePrimitive(builder, value, type, columnName);
+        return null;
+    }
+
+    /**
+     * Serializes a repeated field. With a parent builder the array is written into
+     * it and null is returned; without one the element block itself is returned.
+     */
+    @Nullable
+    private static Block serializeList(BlockBuilder parentBlockBuilder, @Nullable Object value, Type type, String columnName)
+    {
+        if (value == null) {
+            checkState(parentBlockBuilder != null, "parentBlockBuilder is null");
+            parentBlockBuilder.appendNull();
+            return null;
+        }
+        List<?> list = (List<?>) value;
+        List<Type> typeParameters = type.getTypeParameters();
+        Type elementType = typeParameters.get(0);
+
+        BlockBuilder blockBuilder = elementType.createBlockBuilder(null, list.size());
+        for (Object element : list) {
+            serializeObject(blockBuilder, element, elementType, columnName);
+        }
+        if (parentBlockBuilder != null) {
+            type.writeObject(parentBlockBuilder, blockBuilder.build());
+            return null;
+        }
+        return blockBuilder.build();
+    }
+
+    /**
+     * Appends a single non-structural value (or NULL) to {@code blockBuilder}.
+     *
+     * @throws TrinoException with DECODER_CONVERSION_NOT_SUPPORTED when the value class
+     *         does not match the target type
+     */
+    private static void serializePrimitive(BlockBuilder blockBuilder, @Nullable Object value, Type type, String columnName)
+    {
+        requireNonNull(blockBuilder, "parent blockBuilder is null");
+
+        if (value == null) {
+            blockBuilder.appendNull();
+            return;
+        }
+
+        if (type instanceof BooleanType) {
+            type.writeBoolean(blockBuilder, (Boolean) value);
+            return;
+        }
+
+        if ((value instanceof Integer || value instanceof Long) && (type instanceof BigintType || type instanceof IntegerType || type instanceof SmallintType || type instanceof TinyintType)) {
+            type.writeLong(blockBuilder, ((Number) value).longValue());
+            return;
+        }
+
+        if (type instanceof DoubleType && value instanceof Double) {
+            type.writeDouble(blockBuilder, (Double) value);
+            return;
+        }
+
+        if (type instanceof RealType && value instanceof Float) {
+            type.writeLong(blockBuilder, floatToIntBits((Float) value));
+            return;
+        }
+
+        if (type instanceof VarcharType || type instanceof VarbinaryType) {
+            type.writeSlice(blockBuilder, getSlice(value, type, columnName));
+            return;
+        }
+
+        if (type instanceof TimestampType && ((TimestampType) type).isShort()) {
+            checkArgument(value instanceof DynamicMessage, "value should be an instance of DynamicMessage");
+            type.writeLong(blockBuilder, parseTimestamp(((TimestampType) type).getPrecision(), (DynamicMessage) value));
+            return;
+        }
+
+        throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("cannot decode object of '%s' as '%s' for column '%s'", value.getClass(), type, columnName));
+    }
+
+    /**
+     * Serializes a protobuf map, which arrives as a collection of entry messages
+     * whose field number 1 is the key and field number 2 is the value.
+     *
+     * @return the map block when there is no parent builder, otherwise null
+     */
+    @Nullable
+    private static Block serializeMap(BlockBuilder parentBlockBuilder, @Nullable Object value, Type type, String columnName)
+    {
+        if (value == null) {
+            checkState(parentBlockBuilder != null, "parentBlockBuilder is null");
+            parentBlockBuilder.appendNull();
+            return null;
+        }
+
+        Collection<DynamicMessage> dynamicMessages = ((Collection<?>) value).stream()
+                .map(DynamicMessage.class::cast)
+                .collect(toImmutableList());
+        List<Type> typeParameters = type.getTypeParameters();
+        Type keyType = typeParameters.get(0);
+        Type valueType = typeParameters.get(1);
+
+        BlockBuilder blockBuilder;
+        if (parentBlockBuilder != null) {
+            blockBuilder = parentBlockBuilder;
+        }
+        else {
+            blockBuilder = type.createBlockBuilder(null, 1);
+        }
+
+        BlockBuilder entryBuilder = blockBuilder.beginBlockEntry();
+        for (DynamicMessage dynamicMessage : dynamicMessages) {
+            Object key = dynamicMessage.getField(getFieldDescriptor(dynamicMessage, 1));
+            if (key != null) {
+                serializeObject(entryBuilder, key, keyType, columnName);
+                serializeObject(entryBuilder, dynamicMessage.getField(getFieldDescriptor(dynamicMessage, 2)), valueType, columnName);
+            }
+        }
+        blockBuilder.closeEntry();
+
+        if (parentBlockBuilder == null) {
+            return blockBuilder.getObject(0, Block.class);
+        }
+        return null;
+    }
+
+    /**
+     * Serializes a nested message as a row, looking up each row field in the
+     * message by name.
+     *
+     * @return the row block when there is no parent builder, otherwise null
+     * @throws IllegalStateException if a row field is unnamed or has no matching message field
+     */
+    @Nullable
+    private static Block serializeRow(BlockBuilder parentBlockBuilder, @Nullable Object value, Type type, String columnName)
+    {
+        if (value == null) {
+            checkState(parentBlockBuilder != null, "parent block builder is null");
+            parentBlockBuilder.appendNull();
+            return null;
+        }
+
+        BlockBuilder blockBuilder;
+        if (parentBlockBuilder != null) {
+            blockBuilder = parentBlockBuilder;
+        }
+        else {
+            blockBuilder = type.createBlockBuilder(null, 1);
+        }
+        BlockBuilder singleRowBuilder = blockBuilder.beginBlockEntry();
+        DynamicMessage record = (DynamicMessage) value;
+        List<RowType.Field> fields = ((RowType) type).getFields();
+        for (RowType.Field field : fields) {
+            checkState(field.getName().isPresent(), "field name not found");
+            FieldDescriptor fieldDescriptor = getFieldDescriptor(record, field.getName().get());
+            // Message template: formatted only when the check fails
+            checkState(fieldDescriptor != null, "Unknown Field %s", field.getName().get());
+            serializeObject(
+                    singleRowBuilder,
+                    record.getField(fieldDescriptor),
+                    field.getType(),
+                    columnName);
+        }
+        blockBuilder.closeEntry();
+        if (parentBlockBuilder == null) {
+            return blockBuilder.getObject(0, Block.class);
+        }
+        return null;
+    }
+
+    /**
+     * Converts a timestamp message (fields "seconds" and "nanos") into epoch
+     * microseconds rounded to the requested precision.
+     *
+     * @throws IllegalArgumentException if {@code precision} exceeds MAX_SHORT_PRECISION
+     */
+    private static long parseTimestamp(int precision, DynamicMessage timestamp)
+    {
+        // Validate before doing any arithmetic
+        checkArgument(precision <= MAX_SHORT_PRECISION, "precision must be less than max short timestamp precision (" + MAX_SHORT_PRECISION + ")");
+        long seconds = (Long) timestamp.getField(timestamp.getDescriptorForType().findFieldByName("seconds"));
+        int nanos = (Integer) timestamp.getField(timestamp.getDescriptorForType().findFieldByName("nanos"));
+        long micros = seconds * MICROSECONDS_PER_SECOND;
+        micros += roundDiv(nanos, NANOSECONDS_PER_MICROSECOND);
+        return round(micros, MAX_SHORT_PRECISION - precision);
+    }
+
+    private static FieldDescriptor getFieldDescriptor(DynamicMessage message, String name)
+    {
+        return message.getDescriptorForType().findFieldByName(name);
+    }
+
+    private static FieldDescriptor getFieldDescriptor(DynamicMessage message, int index)
+    {
+        return message.getDescriptorForType().findFieldByNumber(index);
+    }
+}
diff --git a/lib/trino-record-decoder/src/test/java/io/trino/decoder/protobuf/ProtobufDataProviders.java b/lib/trino-record-decoder/src/test/java/io/trino/decoder/protobuf/ProtobufDataProviders.java
new file mode 100644
index 000000000000..d9383833d32d
--- /dev/null
+++ b/lib/trino-record-decoder/src/test/java/io/trino/decoder/protobuf/ProtobufDataProviders.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.decoder.protobuf;
+
+import org.testng.annotations.DataProvider;
+
+import java.time.LocalDateTime;
+
+import static io.trino.testing.DateTimeTestingUtils.sqlTimestampOf;
+import static java.lang.Math.PI;
+import static java.nio.charset.StandardCharsets.UTF_8;
+import static java.util.stream.Collectors.joining;
+import static java.util.stream.IntStream.range;
+
+public class ProtobufDataProviders
+{
+    /**
+     * Supplies three rows of test values, each laid out as:
+     * string, integer, long, double, float, boolean, enum constant name,
+     * millisecond-precision timestamp, bytes.
+     */
+    @DataProvider
+    public Object[][] allTypesDataProvider()
+    {
+        Object[] ordinaryRow = {
+                "Trino",
+                1,
+                493857959588286460L,
+                PI,
+                3.14f,
+                true,
+                "ONE",
+                sqlTimestampOf(3, LocalDateTime.parse("2020-12-12T15:35:45.923")),
+                "X'65683F'".getBytes(UTF_8)
+        };
+        Object[] longTextEmptyBytesRow = {
+                commaSeparatedIntegers(0, 5000),
+                Integer.MAX_VALUE,
+                Long.MIN_VALUE,
+                Double.MAX_VALUE,
+                Float.MIN_VALUE,
+                false,
+                "ZERO",
+                sqlTimestampOf(3, LocalDateTime.parse("1856-01-12T05:25:14.456")),
+                new byte[0]
+        };
+        Object[] nonFiniteNumbersRow = {
+                commaSeparatedIntegers(5000, 10000),
+                Integer.MIN_VALUE,
+                Long.MAX_VALUE,
+                Double.NaN,
+                Float.NEGATIVE_INFINITY,
+                false,
+                "ZERO",
+                sqlTimestampOf(3, LocalDateTime.parse("0001-01-01T00:00:00.923")),
+                "X'65683F'".getBytes(UTF_8)
+        };
+        return new Object[][] {ordinaryRow, longTextEmptyBytesRow, nonFiniteNumbersRow};
+    }
+
+    // Joins the integers in [startInclusive, endExclusive) with ", "
+    private static String commaSeparatedIntegers(int startInclusive, int endExclusive)
+    {
+        return range(startInclusive, endExclusive)
+                .mapToObj(Integer::toString)
+                .collect(joining(", "));
+    }
+}
diff --git a/lib/trino-record-decoder/src/test/java/io/trino/decoder/protobuf/TestProtobufDecoder.java b/lib/trino-record-decoder/src/test/java/io/trino/decoder/protobuf/TestProtobufDecoder.java
new file mode 100644
index 000000000000..941c1f43365b
--- /dev/null
+++ b/lib/trino-record-decoder/src/test/java/io/trino/decoder/protobuf/TestProtobufDecoder.java
@@ -0,0 +1,280 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.decoder.protobuf; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Timestamp; +import io.airlift.slice.Slices; +import io.trino.decoder.DecoderColumnHandle; +import io.trino.decoder.DecoderTestColumnHandle; +import io.trino.decoder.FieldValueProvider; +import io.trino.decoder.RowDecoder; +import io.trino.spi.block.Block; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.SqlTimestamp; +import io.trino.spi.type.SqlVarbinary; +import io.trino.testing.TestingSession; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.Set; + +import static io.trino.decoder.protobuf.ProtobufRowDecoderFactory.DEFAULT_MESSAGE; +import static io.trino.decoder.util.DecoderTestUtil.checkValue; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.TimestampType.TIMESTAMP_MICROS; +import static io.trino.spi.type.TimestampType.createTimestampType; +import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND; +import static 
io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND;
+import static io.trino.spi.type.TypeSignature.mapType;
+import static io.trino.spi.type.VarbinaryType.VARBINARY;
+import static io.trino.spi.type.VarcharType.VARCHAR;
+import static io.trino.spi.type.VarcharType.createVarcharType;
+import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER;
+import static java.lang.Math.floorDiv;
+import static java.lang.Math.floorMod;
+import static org.assertj.core.api.Assertions.assertThatThrownBy;
+import static org.testng.Assert.assertEquals;
+
+public class TestProtobufDecoder
+{
+    private static final ProtobufRowDecoderFactory DECODER_FACTORY = new ProtobufRowDecoderFactory(new FixedSchemaDynamicMessageProvider.Factory());
+
+    /**
+     * Encodes one message with every scalar column of all_datatypes.proto set
+     * and checks that each column decodes back to the value that was written.
+     */
+    @Test(dataProvider = "allTypesDataProvider", dataProviderClass = ProtobufDataProviders.class)
+    public void testAllDataTypes(String stringData, Integer integerData, Long longData, Double doubleData, Float floatData, Boolean booleanData, String enumData, SqlTimestamp sqlTimestamp, byte[] bytesData)
+            throws Exception
+    {
+        DecoderTestColumnHandle stringColumn = new DecoderTestColumnHandle(0, "stringColumn", createVarcharType(30000), "stringColumn", null, null, false, false, false);
+        DecoderTestColumnHandle integerColumn = new DecoderTestColumnHandle(1, "integerColumn", INTEGER, "integerColumn", null, null, false, false, false);
+        DecoderTestColumnHandle longColumn = new DecoderTestColumnHandle(2, "longColumn", BIGINT, "longColumn", null, null, false, false, false);
+        DecoderTestColumnHandle doubleColumn = new DecoderTestColumnHandle(3, "doubleColumn", DOUBLE, "doubleColumn", null, null, false, false, false);
+        DecoderTestColumnHandle floatColumn = new DecoderTestColumnHandle(4, "floatColumn", REAL, "floatColumn", null, null, false, false, false);
+        DecoderTestColumnHandle booleanColumn = new DecoderTestColumnHandle(5, "booleanColumn", BOOLEAN, "booleanColumn", null, null, false, false, false);
+        DecoderTestColumnHandle numberColumn = new DecoderTestColumnHandle(6, "numberColumn", createVarcharType(4), "numberColumn", null, null, false, false, false);
+        DecoderTestColumnHandle timestampColumn = new DecoderTestColumnHandle(7, "timestampColumn", createTimestampType(3), "timestampColumn", null, null, false, false, false);
+        DecoderTestColumnHandle bytesColumn = new DecoderTestColumnHandle(8, "bytesColumn", VARBINARY, "bytesColumn", null, null, false, false, false);
+
+        Descriptor descriptor = getDescriptor("all_datatypes.proto");
+        DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(descriptor);
+        messageBuilder.setField(descriptor.findFieldByName("stringColumn"), stringData);
+        messageBuilder.setField(descriptor.findFieldByName("integerColumn"), integerData);
+        messageBuilder.setField(descriptor.findFieldByName("longColumn"), longData);
+        messageBuilder.setField(descriptor.findFieldByName("doubleColumn"), doubleData);
+        messageBuilder.setField(descriptor.findFieldByName("floatColumn"), floatData);
+        messageBuilder.setField(descriptor.findFieldByName("booleanColumn"), booleanData);
+        messageBuilder.setField(descriptor.findFieldByName("numberColumn"), descriptor.findEnumTypeByName("Number").findValueByName(enumData));
+        messageBuilder.setField(descriptor.findFieldByName("timestampColumn"), getTimestamp(sqlTimestamp));
+        messageBuilder.setField(descriptor.findFieldByName("bytesColumn"), bytesData);
+
+        Map<DecoderColumnHandle, FieldValueProvider> decodedRow = createRowDecoder("all_datatypes.proto", ImmutableSet.of(stringColumn, integerColumn, longColumn, doubleColumn, floatColumn, booleanColumn, numberColumn, timestampColumn, bytesColumn))
+                .decodeRow(messageBuilder.build().toByteArray())
+                .orElseThrow(AssertionError::new);
+
+        assertEquals(decodedRow.size(), 9);
+
+        checkValue(decodedRow, stringColumn, stringData);
+        checkValue(decodedRow, integerColumn, integerData);
+        checkValue(decodedRow, longColumn, longData);
+        checkValue(decodedRow, doubleColumn, doubleData);
+        checkValue(decodedRow, floatColumn, floatData);
+        checkValue(decodedRow, booleanColumn, booleanData);
+        checkValue(decodedRow, numberColumn, enumData);
+        checkValue(decodedRow, timestampColumn, sqlTimestamp.getEpochMicros());
+        checkValue(decodedRow, bytesColumn, Slices.wrappedBuffer(bytesData));
+    }
+
+    /**
+     * Decodes the repeated, map and nested-message fields of
+     * structural_datatypes.proto as ARRAY, MAP and ROW columns.
+     */
+    @Test(dataProvider = "allTypesDataProvider", dataProviderClass = ProtobufDataProviders.class)
+    public void testStructuralDataTypes(String stringData, Integer integerData, Long longData, Double doubleData, Float floatData, Boolean booleanData, String enumData, SqlTimestamp sqlTimestamp, byte[] bytesData)
+            throws Exception
+    {
+        DecoderTestColumnHandle listColumn = new DecoderTestColumnHandle(0, "list", new ArrayType(createVarcharType(100)), "list", null, null, false, false, false);
+        DecoderTestColumnHandle mapColumn = new DecoderTestColumnHandle(1, "map", TESTING_TYPE_MANAGER.getType(mapType(VARCHAR.getTypeSignature(), VARCHAR.getTypeSignature())), "map", null, null, false, false, false);
+        DecoderTestColumnHandle rowColumn = new DecoderTestColumnHandle(
+                2,
+                "row",
+                RowType.from(ImmutableList.<RowType.Field>builder()
+                        .add(RowType.field("string_column", createVarcharType(30000)))
+                        .add(RowType.field("integer_column", INTEGER))
+                        .add(RowType.field("long_column", BIGINT))
+                        .add(RowType.field("double_column", DOUBLE))
+                        .add(RowType.field("float_column", REAL))
+                        .add(RowType.field("boolean_column", BOOLEAN))
+                        .add(RowType.field("number_column", createVarcharType(4)))
+                        .add(RowType.field("timestamp_column", createTimestampType(6)))
+                        .add(RowType.field("bytes_column", VARBINARY))
+                        .build()),
+                "row",
+                null,
+                null,
+                false,
+                false,
+                false);
+
+        Descriptor descriptor = getDescriptor("structural_datatypes.proto");
+        DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(descriptor);
+        messageBuilder.setField(descriptor.findFieldByName("list"), ImmutableList.of("Presto"));
+
+        Descriptor mapDescriptor = descriptor.findFieldByName("map").getMessageType();
+        DynamicMessage.Builder mapBuilder = DynamicMessage.newBuilder(mapDescriptor);
+        mapBuilder.setField(mapDescriptor.findFieldByName("key"), "Key");
+        mapBuilder.setField(mapDescriptor.findFieldByName("value"), "Value");
+        messageBuilder.setField(descriptor.findFieldByName("map"), ImmutableList.of(mapBuilder.build()));
+
+        Descriptor rowDescriptor = descriptor.findFieldByName("row").getMessageType();
+        DynamicMessage.Builder rowBuilder = DynamicMessage.newBuilder(rowDescriptor);
+        rowBuilder.setField(rowDescriptor.findFieldByName("string_column"), stringData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("integer_column"), integerData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("long_column"), longData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("double_column"), doubleData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("float_column"), floatData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("boolean_column"), booleanData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("number_column"), descriptor.findEnumTypeByName("Number").findValueByName(enumData));
+        rowBuilder.setField(rowDescriptor.findFieldByName("timestamp_column"), getTimestamp(sqlTimestamp));
+        rowBuilder.setField(rowDescriptor.findFieldByName("bytes_column"), bytesData);
+        messageBuilder.setField(descriptor.findFieldByName("row"), rowBuilder.build());
+
+        Map<DecoderColumnHandle, FieldValueProvider> decodedRow = createRowDecoder("structural_datatypes.proto", ImmutableSet.of(listColumn, mapColumn, rowColumn))
+                .decodeRow(messageBuilder.build().toByteArray())
+                .orElseThrow(AssertionError::new);
+
+        assertEquals(decodedRow.size(), 3);
+
+        Block listBlock = decodedRow.get(listColumn).getBlock();
+        assertEquals(VARCHAR.getSlice(listBlock, 0).toStringUtf8(), "Presto");
+
+        Block mapBlock = decodedRow.get(mapColumn).getBlock();
+        assertEquals(VARCHAR.getSlice(mapBlock, 0).toStringUtf8(), "Key");
+        assertEquals(VARCHAR.getSlice(mapBlock, 1).toStringUtf8(), "Value");
+
+        Block rowBlock = decodedRow.get(rowColumn).getBlock();
+        ConnectorSession session = TestingSession.testSessionBuilder().build().toConnectorSession();
+        assertEquals(VARCHAR.getObjectValue(session, rowBlock, 0), stringData);
+        assertEquals(INTEGER.getObjectValue(session, rowBlock, 1), integerData);
+        assertEquals(BIGINT.getObjectValue(session, rowBlock, 2), longData);
+        assertEquals(DOUBLE.getObjectValue(session, rowBlock, 3), doubleData);
+        assertEquals(REAL.getObjectValue(session, rowBlock, 4), floatData);
+        assertEquals(BOOLEAN.getObjectValue(session, rowBlock, 5), booleanData);
+        assertEquals(VARCHAR.getObjectValue(session, rowBlock, 6), enumData);
+        assertEquals(TIMESTAMP_MICROS.getObjectValue(session, rowBlock, 7), sqlTimestamp.roundTo(6));
+        assertEquals(VARBINARY.getObjectValue(session, rowBlock, 8), new SqlVarbinary(bytesData));
+    }
+
+    /**
+     * A ROW column naming a field that the nested message does not declare
+     * must fail when the block is materialized.
+     */
+    @Test
+    public void testMissingFieldInRowType()
+            throws Exception
+    {
+        DecoderTestColumnHandle rowColumn = new DecoderTestColumnHandle(
+                2,
+                "row",
+                RowType.from(ImmutableList.of(RowType.field("unknown_mapping", createVarcharType(30000)))),
+                "row",
+                null,
+                null,
+                false,
+                false,
+                false);
+
+        Descriptor descriptor = getDescriptor("structural_datatypes.proto");
+        DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(descriptor);
+
+        Descriptor rowDescriptor = descriptor.findFieldByName("row").getMessageType();
+        DynamicMessage.Builder rowBuilder = DynamicMessage.newBuilder(rowDescriptor);
+        rowBuilder.setField(rowDescriptor.findFieldByName("string_column"), "Test");
+        messageBuilder.setField(descriptor.findFieldByName("row"), rowBuilder.build());
+
+        Map<DecoderColumnHandle, FieldValueProvider> decodedRow = createRowDecoder("structural_datatypes.proto", ImmutableSet.of(rowColumn))
+                .decodeRow(messageBuilder.build().toByteArray())
+                .orElseThrow(AssertionError::new);
+
+        assertThatThrownBy(() -> decodedRow.get(rowColumn).getBlock())
+                .hasMessageMatching("Unknown Field unknown_mapping");
+    }
+
+    /**
+     * Maps individual fields of the nested "row" message to top-level columns
+     * using "row/&lt;field&gt;" mappings.
+     */
+    @Test(dataProvider = "allTypesDataProvider", dataProviderClass = ProtobufDataProviders.class)
+    public void testRowFlattening(String stringData, Integer integerData, Long longData, Double doubleData, Float floatData, Boolean booleanData, String enumData, SqlTimestamp sqlTimestamp, byte[] bytesData)
+            throws Exception
+    {
+        DecoderTestColumnHandle stringColumn = new DecoderTestColumnHandle(0, "stringColumn", createVarcharType(30000), "row/string_column", null, null, false, false, false);
+        DecoderTestColumnHandle integerColumn = new DecoderTestColumnHandle(1, "integerColumn", INTEGER, "row/integer_column", null, null, false, false, false);
+        DecoderTestColumnHandle longColumn = new DecoderTestColumnHandle(2, "longColumn", BIGINT, "row/long_column", null, null, false, false, false);
+        DecoderTestColumnHandle doubleColumn = new DecoderTestColumnHandle(3, "doubleColumn", DOUBLE, "row/double_column", null, null, false, false, false);
+        DecoderTestColumnHandle floatColumn = new DecoderTestColumnHandle(4, "floatColumn", REAL, "row/float_column", null, null, false, false, false);
+        DecoderTestColumnHandle booleanColumn = new DecoderTestColumnHandle(5, "booleanColumn", BOOLEAN, "row/boolean_column", null, null, false, false, false);
+        DecoderTestColumnHandle numberColumn = new DecoderTestColumnHandle(6, "numberColumn", createVarcharType(4), "row/number_column", null, null, false, false, false);
+        // Ordinals continue the 0..8 sequence, as in testAllDataTypes
+        DecoderTestColumnHandle timestampColumn = new DecoderTestColumnHandle(7, "timestampColumn", createTimestampType(3), "row/timestamp_column", null, null, false, false, false);
+        DecoderTestColumnHandle bytesColumn = new DecoderTestColumnHandle(8, "bytesColumn", VARBINARY, "row/bytes_column", null, null, false, false, false);
+
+        Descriptor descriptor = getDescriptor("structural_datatypes.proto");
+        DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(descriptor);
+
+        Descriptor rowDescriptor = descriptor.findNestedTypeByName("Row");
+        DynamicMessage.Builder rowBuilder = DynamicMessage.newBuilder(rowDescriptor);
+        rowBuilder.setField(rowDescriptor.findFieldByName("string_column"), stringData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("integer_column"), integerData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("long_column"), longData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("double_column"), doubleData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("float_column"), floatData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("boolean_column"), booleanData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("number_column"), descriptor.findEnumTypeByName("Number").findValueByName(enumData));
+        rowBuilder.setField(rowDescriptor.findFieldByName("timestamp_column"), getTimestamp(sqlTimestamp));
+        rowBuilder.setField(rowDescriptor.findFieldByName("bytes_column"), bytesData);
+        messageBuilder.setField(descriptor.findFieldByName("row"), rowBuilder.build());
+
+        Map<DecoderColumnHandle, FieldValueProvider> decodedRow = createRowDecoder("structural_datatypes.proto", ImmutableSet.of(stringColumn, integerColumn, longColumn, doubleColumn, floatColumn, booleanColumn, numberColumn, timestampColumn, bytesColumn))
+                .decodeRow(messageBuilder.build().toByteArray())
+                .orElseThrow(AssertionError::new);
+
+        assertEquals(decodedRow.size(), 9);
+
+        checkValue(decodedRow, stringColumn, stringData);
+        checkValue(decodedRow, integerColumn, integerData);
+        checkValue(decodedRow, longColumn, longData);
+        checkValue(decodedRow, doubleColumn, doubleData);
+        checkValue(decodedRow, floatColumn, floatData);
+        checkValue(decodedRow, booleanColumn, booleanData);
+        checkValue(decodedRow, numberColumn, enumData);
+        checkValue(decodedRow, timestampColumn, sqlTimestamp.getEpochMicros());
+        checkValue(decodedRow, bytesColumn, Slices.wrappedBuffer(bytesData));
+    }
+
+    // Splits epoch microseconds into the seconds/nanos pair of a protobuf Timestamp;
+    // floorDiv/floorMod keep nanos non-negative for instants before the epoch
+    private Timestamp getTimestamp(SqlTimestamp sqlTimestamp)
+    {
+        return Timestamp.newBuilder()
+                .setSeconds(floorDiv(sqlTimestamp.getEpochMicros(), MICROSECONDS_PER_SECOND))
+                .setNanos(floorMod(sqlTimestamp.getEpochMicros(), MICROSECONDS_PER_SECOND) * NANOSECONDS_PER_MICROSECOND)
+                .build();
+    }
+
+    // Builds a decoder for the given columns from a proto file under decoder/protobuf/
+    private RowDecoder createRowDecoder(String fileName, Set<DecoderColumnHandle> columns)
+            throws Exception
+    {
+        return DECODER_FACTORY.create(ImmutableMap.of("dataSchema", ProtobufUtils.getProtoFile("decoder/protobuf/" + fileName)), columns);
+    }
+
+    // Loads the DEFAULT_MESSAGE descriptor from a proto file under decoder/protobuf/
+    private Descriptor getDescriptor(String fileName)
+            throws Exception
+    {
+        return ProtobufUtils.getFileDescriptor(ProtobufUtils.getProtoFile("decoder/protobuf/" + fileName)).findMessageTypeByName(DEFAULT_MESSAGE);
+    }
+}
diff --git a/lib/trino-record-decoder/src/test/resources/decoder/protobuf/all_datatypes.proto b/lib/trino-record-decoder/src/test/resources/decoder/protobuf/all_datatypes.proto
new file mode 100644
index 000000000000..06b9c5421c89
--- /dev/null
+++ b/lib/trino-record-decoder/src/test/resources/decoder/protobuf/all_datatypes.proto
@@ -0,0 +1,19 @@
+syntax = "proto3";
+
+import "google/protobuf/timestamp.proto";
+
+message schema {
+    string stringColumn = 1 ;
+    uint32 integerColumn = 2;
+    uint64 longColumn = 3;
+    double doubleColumn = 4;
+    float floatColumn = 5;
+    bool booleanColumn = 6;
+    enum Number {
+        ZERO = 0;
+        ONE = 1;
+    };
+    Number numberColumn = 7;
+    google.protobuf.Timestamp timestampColumn = 8;
+    bytes bytesColumn = 9;
+}
diff --git a/lib/trino-record-decoder/src/test/resources/decoder/protobuf/structural_datatypes.proto b/lib/trino-record-decoder/src/test/resources/decoder/protobuf/structural_datatypes.proto
new file mode 100644
index 000000000000..8e3083e2763a
--- /dev/null
+++ b/lib/trino-record-decoder/src/test/resources/decoder/protobuf/structural_datatypes.proto
@@ -0,0 +1,31 @@
+syntax = "proto3";
+
+import "google/protobuf/timestamp.proto";
+
+message schema {
+    repeated string list = 1;
+    map<string, string> map = 2;
+    enum Number {
+        ZERO = 0;
+        ONE = 1;
+        TWO = 2;
+    };
+    message Row {
+        string string_column = 1;
+        uint32 integer_column = 2;
+        uint64 long_column = 3;
+        double double_column = 4;
+        float float_column = 5;
+        bool boolean_column =
6; + Number number_column = 7; + google.protobuf.Timestamp timestamp_column = 8; + bytes bytes_column = 9; + }; + Row row = 3; + message NestedRow { + repeated Row nested_list = 1; + map nested_map = 2; + Row row = 3; + }; + NestedRow nested_row = 4; +} diff --git a/plugin/trino-kafka/pom.xml b/plugin/trino-kafka/pom.xml index 35685c73bb37..30095000975d 100644 --- a/plugin/trino-kafka/pom.xml +++ b/plugin/trino-kafka/pom.xml @@ -272,4 +272,19 @@ test + + + + + org.basepom.maven + duplicate-finder-maven-plugin + + + + google/protobuf/.*\.proto$ + + + + + diff --git a/plugin/trino-kinesis/pom.xml b/plugin/trino-kinesis/pom.xml index 398f7b567d21..2c4febf98b77 100644 --- a/plugin/trino-kinesis/pom.xml +++ b/plugin/trino-kinesis/pom.xml @@ -186,6 +186,16 @@ + + org.basepom.maven + duplicate-finder-maven-plugin + + + + google/protobuf/.*\.proto$ + + + diff --git a/plugin/trino-redis/pom.xml b/plugin/trino-redis/pom.xml index e2720b124d17..deee4c884882 100644 --- a/plugin/trino-redis/pom.xml +++ b/plugin/trino-redis/pom.xml @@ -201,4 +201,19 @@ test + + + + + org.basepom.maven + duplicate-finder-maven-plugin + + + + google/protobuf/.*\.proto$ + + + + + diff --git a/pom.xml b/pom.xml index 1b21c263270d..e8ef46ac2465 100644 --- a/pom.xml +++ b/pom.xml @@ -68,6 +68,9 @@ 7.1.4 1.0.0 4.7.2 + 3.21.6 + 3.2.2 + 1.4.0 72 @@ -1182,6 +1185,18 @@ ${dep.errorprone.version} + + com.google.protobuf + protobuf-java + ${dep.protobuf.version} + + + + com.google.protobuf + protobuf-java-util + ${dep.protobuf.version} + + com.h2database h2 @@ -1301,6 +1316,13 @@ ${dep.okhttp.version} + + com.squareup.wire + wire-schema + ${dep.wire.version} + + + com.teradata re2j-td @@ -1725,6 +1747,18 @@ 19.0.0 + + org.jetbrains.kotlin + kotlin-stdlib + ${dep.kotlin.version} + + + + org.jetbrains.kotlin + kotlin-stdlib-common + ${dep.kotlin.version} + + org.jgrapht jgrapht-core