From b798bdec15951ddb2cdaeb45d118dc8c34cd9d05 Mon Sep 17 00:00:00 2001
From: Neville Li
Date: Tue, 8 Nov 2022 09:30:37 -0500
Subject: [PATCH 1/2] Add Protobuf decoder

---
 lib/trino-record-decoder/pom.xml              |  31 ++
 .../java/io/trino/decoder/DecoderModule.java  |   2 +
 .../protobuf/DynamicMessageProvider.java      |  28 ++
 .../FixedSchemaDynamicMessageProvider.java    |  67 ++++
 .../protobuf/ProtobufColumnDecoder.java       | 147 ++++++++
 .../protobuf/ProtobufDecoderModule.java       |  32 ++
 .../decoder/protobuf/ProtobufErrorCode.java   |  43 +++
 .../decoder/protobuf/ProtobufRowDecoder.java  |  55 +++
 .../protobuf/ProtobufRowDecoderFactory.java   |  49 +++
 .../trino/decoder/protobuf/ProtobufUtils.java | 272 +++++++++++++++
 .../protobuf/ProtobufValueProvider.java       | 322 ++++++++++++++++++
 .../protobuf/ProtobufDataProviders.java       |  71 ++++
 .../decoder/protobuf/TestProtobufDecoder.java | 280 +++++++++++++++
 .../decoder/protobuf/all_datatypes.proto      |  19 ++
 .../protobuf/structural_datatypes.proto       |  31 ++
 plugin/trino-kafka/pom.xml                    |  15 +
 plugin/trino-kinesis/pom.xml                  |  10 +
 plugin/trino-redis/pom.xml                    |  15 +
 pom.xml                                       |  34 ++
 19 files changed, 1523 insertions(+)
 create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/DynamicMessageProvider.java
 create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/FixedSchemaDynamicMessageProvider.java
 create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufColumnDecoder.java
 create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufDecoderModule.java
 create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufErrorCode.java
 create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufRowDecoder.java
 create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufRowDecoderFactory.java
 create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufUtils.java
 create mode 100644 lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufValueProvider.java
 create mode 100644 lib/trino-record-decoder/src/test/java/io/trino/decoder/protobuf/ProtobufDataProviders.java
 create mode 100644 lib/trino-record-decoder/src/test/java/io/trino/decoder/protobuf/TestProtobufDecoder.java
 create mode 100644 lib/trino-record-decoder/src/test/resources/decoder/protobuf/all_datatypes.proto
 create mode 100644 lib/trino-record-decoder/src/test/resources/decoder/protobuf/structural_datatypes.proto

diff --git a/lib/trino-record-decoder/pom.xml b/lib/trino-record-decoder/pom.xml
index cf29152408a8..996425a5eeb4 100644
--- a/lib/trino-record-decoder/pom.xml
+++ b/lib/trino-record-decoder/pom.xml
@@ -33,6 +33,12 @@
             <artifactId>jackson-databind</artifactId>
         </dependency>
 
+        <dependency>
+            <groupId>com.google.code.findbugs</groupId>
+            <artifactId>jsr305</artifactId>
+            <optional>true</optional>
+        </dependency>
+
         <dependency>
             <groupId>com.google.guava</groupId>
             <artifactId>guava</artifactId>
@@ -43,6 +49,16 @@
             <artifactId>guice</artifactId>
         </dependency>
 
+        <dependency>
+            <groupId>com.google.protobuf</groupId>
+            <artifactId>protobuf-java</artifactId>
+        </dependency>
+
+        <dependency>
+            <groupId>com.squareup.wire</groupId>
+            <artifactId>wire-schema</artifactId>
+        </dependency>
+
         <dependency>
             <groupId>javax.inject</groupId>
             <artifactId>javax.inject</artifactId>
@@ -107,4 +123,19 @@
             <scope>test</scope>
         </dependency>
     </dependencies>
+
+    <build>
+        <plugins>
+            <plugin>
+                <groupId>org.basepom.maven</groupId>
+                <artifactId>duplicate-finder-maven-plugin</artifactId>
+
+                <configuration>
+                    <ignoredResourcePatterns>
+                        <ignoredResourcePattern>google/protobuf/.*\.proto$</ignoredResourcePattern>
+                    </ignoredResourcePatterns>
+                </configuration>
+            </plugin>
+        </plugins>
+    </build>
 </project>

diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/DecoderModule.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/DecoderModule.java
index 75c4c519c72a..d5702fea6031 100644
--- a/lib/trino-record-decoder/src/main/java/io/trino/decoder/DecoderModule.java
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/DecoderModule.java
@@ -23,6 +23,7 @@ import io.trino.decoder.dummy.DummyRowDecoderFactory;
 import io.trino.decoder.json.JsonRowDecoder;
 import io.trino.decoder.json.JsonRowDecoderFactory;
+import io.trino.decoder.protobuf.ProtobufDecoderModule;
 import io.trino.decoder.raw.RawRowDecoder;
 import io.trino.decoder.raw.RawRowDecoderFactory;
 
@@ -43,6 +44,7 @@ public void configure(Binder binder)
         decoderFactoriesByName.addBinding(JsonRowDecoder.NAME).to(JsonRowDecoderFactory.class).in(SINGLETON);
         decoderFactoriesByName.addBinding(RawRowDecoder.NAME).to(RawRowDecoderFactory.class).in(SINGLETON);
         binder.install(new AvroDecoderModule());
+        binder.install(new ProtobufDecoderModule());
         binder.bind(DispatchingRowDecoderFactory.class).in(SINGLETON);
     }
 }

diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/DynamicMessageProvider.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/DynamicMessageProvider.java
new file mode 100644
index 000000000000..bac8ad6761ee
--- /dev/null
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/DynamicMessageProvider.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.decoder.protobuf;
+
+import com.google.protobuf.DynamicMessage;
+
+import java.util.Optional;
+
+public interface DynamicMessageProvider
+{
+    DynamicMessage parseDynamicMessage(byte[] data);
+
+    interface Factory
+    {
+        DynamicMessageProvider create(Optional<String> protoFile);
+    }
+}
diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/FixedSchemaDynamicMessageProvider.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/FixedSchemaDynamicMessageProvider.java
new file mode 100644
index 000000000000..61b6c92305a6
--- /dev/null
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/FixedSchemaDynamicMessageProvider.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.decoder.protobuf;
+
+import com.google.protobuf.Descriptors.Descriptor;
+import com.google.protobuf.Descriptors.DescriptorValidationException;
+import com.google.protobuf.DynamicMessage;
+import com.google.protobuf.InvalidProtocolBufferException;
+import io.trino.spi.TrinoException;
+
+import java.util.Optional;
+
+import static com.google.common.base.Preconditions.checkState;
+import static io.trino.decoder.protobuf.ProtobufRowDecoderFactory.DEFAULT_MESSAGE;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+
+public class FixedSchemaDynamicMessageProvider
+        implements DynamicMessageProvider
+{
+    private final Descriptor descriptor;
+
+    public FixedSchemaDynamicMessageProvider(Descriptor descriptor)
+    {
+        this.descriptor = requireNonNull(descriptor, "descriptor is null");
+    }
+
+    @Override
+    public DynamicMessage parseDynamicMessage(byte[] data)
+    {
+        try {
+            return DynamicMessage.parseFrom(descriptor, data);
+        }
+        catch (InvalidProtocolBufferException e) {
+            throw new TrinoException(ProtobufErrorCode.INVALID_PROTOBUF_MESSAGE, "Decoding Protobuf record failed.", e);
+        }
+    }
+
+    public static class Factory
+            implements DynamicMessageProvider.Factory
+    {
+        @Override
+        public DynamicMessageProvider create(Optional<String> protoFile)
+        {
+            checkState(protoFile.isPresent(), "proto file is missing");
+            try {
+                Descriptor descriptor = ProtobufUtils.getFileDescriptor(protoFile.orElseThrow()).findMessageTypeByName(DEFAULT_MESSAGE);
+                checkState(descriptor != null, format("Message %s not found", DEFAULT_MESSAGE));
+                return new FixedSchemaDynamicMessageProvider(descriptor);
+            }
+            catch (DescriptorValidationException descriptorValidationException) {
+                throw new TrinoException(ProtobufErrorCode.INVALID_PROTO_FILE, "Unable to parse protobuf schema", descriptorValidationException);
+            }
+        }
+    }
+}
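A minimal usage sketch (an editorial illustration, not part of the patch): the factory parses the schema text, resolves the top-level message — which must be named `schema`, matching `ProtobufRowDecoderFactory.DEFAULT_MESSAGE` below — and returns a provider that deserializes payload bytes. The inline schema string and `payloadBytes` are illustrative assumptions:

    // Hedged sketch: build a provider from schema text and parse one record.
    String protoFileContents = ""
            + "syntax = \"proto3\";\n"
            + "message schema {\n"
            + "  string col = 1;\n"
            + "}\n";
    DynamicMessageProvider provider = new FixedSchemaDynamicMessageProvider.Factory()
            .create(Optional.of(protoFileContents));
    DynamicMessage message = provider.parseDynamicMessage(payloadBytes); // payloadBytes: serialized record, assumed in scope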
diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufColumnDecoder.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufColumnDecoder.java
new file mode 100644
index 000000000000..56118db9f00e
--- /dev/null
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufColumnDecoder.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.decoder.protobuf;
+
+import com.google.common.base.Splitter;
+import com.google.common.collect.ImmutableSet;
+import com.google.protobuf.Descriptors.Descriptor;
+import com.google.protobuf.Descriptors.FieldDescriptor;
+import com.google.protobuf.DynamicMessage;
+import io.trino.decoder.DecoderColumnHandle;
+import io.trino.decoder.FieldValueProvider;
+import io.trino.spi.TrinoException;
+import io.trino.spi.type.ArrayType;
+import io.trino.spi.type.BigintType;
+import io.trino.spi.type.BooleanType;
+import io.trino.spi.type.DoubleType;
+import io.trino.spi.type.IntegerType;
+import io.trino.spi.type.MapType;
+import io.trino.spi.type.RealType;
+import io.trino.spi.type.RowType;
+import io.trino.spi.type.SmallintType;
+import io.trino.spi.type.TimestampType;
+import io.trino.spi.type.TinyintType;
+import io.trino.spi.type.Type;
+import io.trino.spi.type.VarbinaryType;
+import io.trino.spi.type.VarcharType;
+
+import javax.annotation.Nullable;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static io.trino.spi.StandardErrorCode.GENERIC_USER_ERROR;
+import static java.util.Objects.requireNonNull;
+
+public class ProtobufColumnDecoder
+{
+    private static final Set<Type> SUPPORTED_PRIMITIVE_TYPES = ImmutableSet.of(
+            BooleanType.BOOLEAN,
+            TinyintType.TINYINT,
+            SmallintType.SMALLINT,
+            IntegerType.INTEGER,
+            BigintType.BIGINT,
+            RealType.REAL,
+            DoubleType.DOUBLE,
+            VarbinaryType.VARBINARY);
+
+    private final Type columnType;
+    private final String columnMapping;
+    private final String columnName;
+
+    public ProtobufColumnDecoder(DecoderColumnHandle columnHandle)
+    {
+        try {
+            requireNonNull(columnHandle, "columnHandle is null");
+            this.columnType = columnHandle.getType();
+            this.columnMapping = columnHandle.getMapping();
+            this.columnName = columnHandle.getName();
+            checkArgument(!columnHandle.isInternal(), "unexpected internal column '%s'", columnName);
+            checkArgument(columnHandle.getFormatHint() == null, "unexpected format hint '%s' defined for column '%s'", columnHandle.getFormatHint(), columnName);
+            checkArgument(columnHandle.getDataFormat() == null, "unexpected data format '%s' defined for column '%s'", columnHandle.getDataFormat(), columnName);
+            checkArgument(columnHandle.getMapping() != null, "mapping not defined for column '%s'", columnName);
+
+            checkArgument(isSupportedType(columnType), "Unsupported column type '%s' for column '%s'", columnType, columnName);
+        }
+        catch (IllegalArgumentException e) {
+            throw new TrinoException(GENERIC_USER_ERROR, e);
+        }
+    }
+
+    private static boolean isSupportedType(Type type)
+    {
+        if (isSupportedPrimitive(type)) {
+            return true;
+        }
+
+        if (type instanceof ArrayType) {
+            checkArgument(type.getTypeParameters().size() == 1, "expecting exactly one type parameter for array");
+            return isSupportedType(type.getTypeParameters().get(0));
+        }
+
+        if (type instanceof MapType) {
+            List<Type> typeParameters = type.getTypeParameters();
+            checkArgument(typeParameters.size() == 2, "expecting exactly two type parameters for map");
+            return isSupportedType(typeParameters.get(0)) && isSupportedType(typeParameters.get(1));
+        }
+
+        if (type instanceof RowType) {
+            for (Type fieldType : type.getTypeParameters()) {
+                if (!isSupportedType(fieldType)) {
+                    return false;
+                }
+            }
+            return true;
+        }
+        return false;
+    }
+
+    private static boolean isSupportedPrimitive(Type type)
+    {
+        return (type instanceof TimestampType && ((TimestampType) type).isShort()) ||
+                type instanceof VarcharType ||
+                SUPPORTED_PRIMITIVE_TYPES.contains(type);
+    }
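+
+    // Note: the column mapping is a '/'-separated path into the message; e.g. the
+    // mapping "row/string_column" reads the top-level field "row" and then its
+    // nested field "string_column". locateField below returns null for a missing
+    // path element, which surfaces as SQL NULL rather than a decoding error.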
+
+    public FieldValueProvider decodeField(DynamicMessage dynamicMessage)
+    {
+        return new ProtobufValueProvider(locateField(dynamicMessage, columnMapping), columnType, columnName);
+    }
+
+    @Nullable
+    private static Object locateField(DynamicMessage message, String columnMapping)
+    {
+        Object value = message;
+        Optional<Descriptor> valueDescriptor = Optional.of(message.getDescriptorForType());
+        for (String pathElement : Splitter.on('/').omitEmptyStrings().split(columnMapping)) {
+            if (valueDescriptor.filter(descriptor -> descriptor.findFieldByName(pathElement) != null).isEmpty()) {
+                return null;
+            }
+            FieldDescriptor fieldDescriptor = valueDescriptor.get().findFieldByName(pathElement);
+            value = ((DynamicMessage) value).getField(fieldDescriptor);
+            valueDescriptor = getDescriptor(fieldDescriptor);
+        }
+        return value;
+    }
+
+    private static Optional<Descriptor> getDescriptor(FieldDescriptor fieldDescriptor)
+    {
+        if (fieldDescriptor.getJavaType() == FieldDescriptor.JavaType.MESSAGE) {
+            return Optional.of(fieldDescriptor.getMessageType());
+        }
+        return Optional.empty();
+    }
+}

diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufDecoderModule.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufDecoderModule.java
new file mode 100644
index 000000000000..e924905128aa
--- /dev/null
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufDecoderModule.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.decoder.protobuf;
+
+import com.google.inject.Binder;
+import com.google.inject.Module;
+import io.trino.decoder.RowDecoderFactory;
+
+import static com.google.inject.Scopes.SINGLETON;
+import static com.google.inject.multibindings.MapBinder.newMapBinder;
+
+public class ProtobufDecoderModule
+        implements Module
+{
+    @Override
+    public void configure(Binder binder)
+    {
+        binder.bind(DynamicMessageProvider.Factory.class).to(FixedSchemaDynamicMessageProvider.Factory.class).in(SINGLETON);
+        newMapBinder(binder, String.class, RowDecoderFactory.class).addBinding(ProtobufRowDecoder.NAME).to(ProtobufRowDecoderFactory.class).in(SINGLETON);
+    }
+}

diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufErrorCode.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufErrorCode.java
new file mode 100644
index 000000000000..7ca0076099ca
--- /dev/null
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufErrorCode.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.decoder.protobuf; + +import io.trino.spi.ErrorCode; +import io.trino.spi.ErrorCodeSupplier; +import io.trino.spi.ErrorType; + +import static io.trino.spi.ErrorType.EXTERNAL; + +public enum ProtobufErrorCode + implements ErrorCodeSupplier +{ + INVALID_PROTO_FILE(0, EXTERNAL), + MESSAGE_NOT_FOUND(1, EXTERNAL), + INVALID_PROTOBUF_MESSAGE(2, EXTERNAL), + INVALID_TIMESTAMP(3, EXTERNAL), + /**/; + + private final ErrorCode errorCode; + + ProtobufErrorCode(int code, ErrorType type) + { + errorCode = new ErrorCode(code + 0x0606_0000, name(), type); + } + + @Override + public ErrorCode toErrorCode() + { + return errorCode; + } +} diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufRowDecoder.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufRowDecoder.java new file mode 100644 index 000000000000..441e040ae1f7 --- /dev/null +++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufRowDecoder.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package io.trino.decoder.protobuf;
+
+import com.google.protobuf.DynamicMessage;
+import io.trino.decoder.DecoderColumnHandle;
+import io.trino.decoder.FieldValueProvider;
+import io.trino.decoder.RowDecoder;
+
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
+import static java.util.Objects.requireNonNull;
+import static java.util.function.UnaryOperator.identity;
+
+public class ProtobufRowDecoder
+        implements RowDecoder
+{
+    public static final String NAME = "protobuf";
+
+    private final DynamicMessageProvider dynamicMessageProvider;
+    private final Map<DecoderColumnHandle, ProtobufColumnDecoder> columnDecoders;
+
+    public ProtobufRowDecoder(DynamicMessageProvider dynamicMessageProvider, Set<DecoderColumnHandle> columns)
+    {
+        this.dynamicMessageProvider = requireNonNull(dynamicMessageProvider, "dynamicMessageProvider is null");
+        this.columnDecoders = columns.stream()
+                .collect(toImmutableMap(
+                        identity(),
+                        ProtobufColumnDecoder::new));
+    }
+
+    @Override
+    public Optional<Map<DecoderColumnHandle, FieldValueProvider>> decodeRow(byte[] data)
+    {
+        DynamicMessage message = dynamicMessageProvider.parseDynamicMessage(data);
+        return Optional.of(columnDecoders.entrySet().stream()
+                .collect(toImmutableMap(
+                        Map.Entry::getKey,
+                        entry -> entry.getValue().decodeField(message))));
+    }
+}
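Continuing the earlier sketch (again an illustration, not part of the patch), a full row decode maps each column handle to a lazily evaluated FieldValueProvider. The `provider` from the previous sketch and a varchar column handle `colHandle` with mapping `col` are assumed:

    // Hedged sketch: decode raw bytes into per-column value providers.
    RowDecoder decoder = new ProtobufRowDecoder(provider, ImmutableSet.of(colHandle));
    Map<DecoderColumnHandle, FieldValueProvider> decoded = decoder.decodeRow(payloadBytes)
            .orElseThrow(AssertionError::new);
    Slice value = decoded.get(colHandle).getSlice(); // varchar columns surface as Slices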
diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufRowDecoderFactory.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufRowDecoderFactory.java
new file mode 100644
index 000000000000..52778be4f1b1
--- /dev/null
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufRowDecoderFactory.java
@@ -0,0 +1,49 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.decoder.protobuf;
+
+import io.trino.decoder.DecoderColumnHandle;
+import io.trino.decoder.RowDecoder;
+import io.trino.decoder.RowDecoderFactory;
+import io.trino.decoder.protobuf.DynamicMessageProvider.Factory;
+
+import javax.inject.Inject;
+
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+import static java.util.Objects.requireNonNull;
+
+public class ProtobufRowDecoderFactory
+        implements RowDecoderFactory
+{
+    public static final String DEFAULT_MESSAGE = "schema";
+
+    private final Factory dynamicMessageProviderFactory;
+
+    @Inject
+    public ProtobufRowDecoderFactory(Factory dynamicMessageProviderFactory)
+    {
+        this.dynamicMessageProviderFactory = requireNonNull(dynamicMessageProviderFactory, "dynamicMessageProviderFactory is null");
+    }
+
+    @Override
+    public RowDecoder create(Map<String, String> decoderParams, Set<DecoderColumnHandle> columns)
+    {
+        return new ProtobufRowDecoder(
+                dynamicMessageProviderFactory.create(Optional.ofNullable(decoderParams.get("dataSchema"))),
+                columns);
+    }
+}

diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufUtils.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufUtils.java
new file mode 100644
index 000000000000..02082c37c612
--- /dev/null
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufUtils.java
@@ -0,0 +1,272 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package io.trino.decoder.protobuf; + +import com.google.common.io.Resources; +import com.google.protobuf.DescriptorProtos.DescriptorProto; +import com.google.protobuf.DescriptorProtos.EnumDescriptorProto; +import com.google.protobuf.DescriptorProtos.EnumValueDescriptorProto; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto; +import com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Label; +import com.google.protobuf.DescriptorProtos.FileDescriptorProto; +import com.google.protobuf.DescriptorProtos.MessageOptions; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.DescriptorValidationException; +import com.google.protobuf.Descriptors.FileDescriptor; +import com.squareup.wire.schema.Field; +import com.squareup.wire.schema.Location; +import com.squareup.wire.schema.ProtoType; +import com.squareup.wire.schema.internal.parser.EnumConstantElement; +import com.squareup.wire.schema.internal.parser.EnumElement; +import com.squareup.wire.schema.internal.parser.FieldElement; +import com.squareup.wire.schema.internal.parser.MessageElement; +import com.squareup.wire.schema.internal.parser.ProtoFileElement; +import com.squareup.wire.schema.internal.parser.ProtoParser; +import com.squareup.wire.schema.internal.parser.TypeElement; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.util.HashSet; +import java.util.Optional; +import java.util.Set; + +import static com.google.common.base.CaseFormat.LOWER_UNDERSCORE; +import static com.google.common.base.CaseFormat.UPPER_CAMEL; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_BOOL; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_BYTES; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_DOUBLE; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_ENUM; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_FIXED32; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_FIXED64; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_FLOAT; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_GROUP; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT32; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_INT64; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_MESSAGE; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_SFIXED32; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_SFIXED64; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_SINT32; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_SINT64; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_STRING; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_UINT32; +import static com.google.protobuf.DescriptorProtos.FieldDescriptorProto.Type.TYPE_UINT64; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Objects.requireNonNull; + +public final class ProtobufUtils +{ + private ProtobufUtils() + {} + + public static FileDescriptor getFileDescriptor(String protoFile) + throws DescriptorValidationException + { + ProtoFileElement protoFileElement = 
ProtoParser.Companion.parse(Location.get(""), protoFile);
+        return getFileDescriptor(Optional.empty(), protoFileElement);
+    }
+
+    public static FileDescriptor getFileDescriptor(Optional<String> fileName, ProtoFileElement protoFileElement)
+            throws DescriptorValidationException
+    {
+        FileDescriptor[] dependencies = new FileDescriptor[protoFileElement.getImports().size()];
+        Set<String> definedMessages = new HashSet<>();
+        int index = 0;
+        for (String importStatement : protoFileElement.getImports()) {
+            try {
+                FileDescriptor fileDescriptor = getFileDescriptor(getProtoFile(importStatement));
+                fileDescriptor.getMessageTypes().stream()
+                        .map(Descriptor::getFullName)
+                        .forEach(definedMessages::add);
+                dependencies[index] = fileDescriptor;
+                index++;
+            }
+            catch (IOException exception) {
+                throw new UncheckedIOException(exception);
+            }
+        }
+
+        FileDescriptorProto.Builder builder = FileDescriptorProto.newBuilder();
+
+        if (protoFileElement.getSyntax() != null) {
+            builder.setSyntax(protoFileElement.getSyntax().name());
+        }
+
+        fileName.ifPresent(builder::setName);
+        builder.addAllDependency(protoFileElement.getImports());
+
+        if (protoFileElement.getPackageName() != null) {
+            builder.setPackage(protoFileElement.getPackageName());
+        }
+
+        for (TypeElement element : protoFileElement.getTypes()) {
+            if (element instanceof MessageElement) {
+                builder.addMessageType(processMessage((MessageElement) element, definedMessages));
+                definedMessages.add(element.getName());
+            }
+            if (element instanceof EnumElement) {
+                builder.addEnumType(processEnum((EnumElement) element));
+            }
+        }
+
+        return FileDescriptor.buildFrom(builder.build(), dependencies);
+    }
+
+    public static String getProtoFile(String filePath)
+            throws IOException
+    {
+        return Resources.toString(Resources.getResource(ProtobufUtils.class, "/" + filePath), UTF_8);
+    }
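+
+    // A proto3 map field (e.g. "map<string, string> map = 2;") has no direct
+    // DescriptorProto form, so processMessage below lowers it the way protoc
+    // does: a repeated, synthesized entry message (named after the field in
+    // upper camel case plus "Entry", e.g. "MapEntry") whose field 1 is the key
+    // and field 2 is the value, flagged with the map_entry option.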
+    private static DescriptorProto processMessage(MessageElement message, Set<String> globallyDefinedMessages)
+    {
+        DescriptorProto.Builder builder = DescriptorProto.newBuilder();
+        builder.setName(message.getName());
+        Set<String> definedMessages = new HashSet<>(globallyDefinedMessages);
+        for (TypeElement typeElement : message.getNestedTypes()) {
+            if (typeElement instanceof EnumElement) {
+                builder.addEnumType(processEnum((EnumElement) typeElement));
+            }
+            if (typeElement instanceof MessageElement) {
+                builder.addNestedType(processMessage((MessageElement) typeElement, definedMessages));
+                definedMessages.add(typeElement.getName());
+            }
+        }
+        for (FieldElement field : message.getFields()) {
+            ProtoType protoType = ProtoType.get(field.getType());
+            FieldDescriptorProto.Builder fieldDescriptor = FieldDescriptorProto.newBuilder()
+                    .setName(field.getName())
+                    .setNumber(field.getTag());
+            if (protoType.isMap()) {
+                requireNonNull(protoType.getKeyType(), "keyType is null");
+                requireNonNull(protoType.getValueType(), "valueType is null");
+                builder.addNestedType(DescriptorProto.newBuilder()
+                        // Entry message name is the field name in upper camel case, e.g. "map" -> "MapEntry"
+                        .setName(getNameForMapField(field.getName()))
+                        .setOptions(MessageOptions.newBuilder().setMapEntry(true).build())
+                        .addField(
+                                processType(
+                                        FieldDescriptorProto.newBuilder()
+                                                .setName("key")
+                                                .setNumber(1)
+                                                .setLabel(Label.LABEL_OPTIONAL),
+                                        protoType.getKeyType(),
+                                        definedMessages))
+                        .addField(
+                                processType(
+                                        FieldDescriptorProto.newBuilder()
+                                                .setName("value")
+                                                .setNumber(2)
+                                                .setLabel(Label.LABEL_OPTIONAL),
+                                        protoType.getValueType(),
+                                        definedMessages))
+                        .build());
+                // The field itself becomes a repeated reference to the synthesized entry message
+                fieldDescriptor.setType(TYPE_MESSAGE)
+                        .setLabel(Label.LABEL_REPEATED)
+                        .setTypeName(getNameForMapField(field.getName()));
+            }
+            else {
+                processType(fieldDescriptor, protoType, definedMessages);
+            }
+            if (field.getLabel() != null && field.getLabel() != Field.Label.ONE_OF) {
+                fieldDescriptor.setLabel(getLabel(field.getLabel()));
+            }
+            if (field.getDefaultValue() != null) {
+                fieldDescriptor.setDefaultValue(field.getDefaultValue());
+            }
+            builder.addField(fieldDescriptor.build());
+        }
+        return builder.build();
+    }
+
+    private static EnumDescriptorProto processEnum(EnumElement enumElement)
+    {
+        EnumDescriptorProto.Builder enumBuilder = EnumDescriptorProto.newBuilder();
+        enumBuilder.setName(enumElement.getName());
+        for (EnumConstantElement enumConstant : enumElement.getConstants()) {
+            enumBuilder.addValue(EnumValueDescriptorProto.newBuilder()
+                    .setName(enumConstant.getName())
+                    .setNumber(enumConstant.getTag())
+                    .build());
+        }
+        return enumBuilder.build();
+    }
+
+    public static FieldDescriptorProto.Builder processType(FieldDescriptorProto.Builder builder, ProtoType type, Set<String> messageNames)
+    {
+        switch (type.getSimpleName()) {
+            case "double":
+                return builder.setType(TYPE_DOUBLE);
+            case "float":
+                return builder.setType(TYPE_FLOAT);
+            case "int64":
+                return builder.setType(TYPE_INT64);
+            case "uint64":
+                return builder.setType(TYPE_UINT64);
+            case "int32":
+                return builder.setType(TYPE_INT32);
+            case "fixed64":
+                return builder.setType(TYPE_FIXED64);
+            case "fixed32":
+                return builder.setType(TYPE_FIXED32);
+            case "bool":
+                return builder.setType(TYPE_BOOL);
+            case "string":
+                return builder.setType(TYPE_STRING);
+            case "group":
+                return builder.setType(TYPE_GROUP);
+            case "bytes":
+                return builder.setType(TYPE_BYTES);
+            case "uint32":
+                return builder.setType(TYPE_UINT32);
+            case "sfixed32":
+                return builder.setType(TYPE_SFIXED32);
+            case "sfixed64":
+                return builder.setType(TYPE_SFIXED64);
+            case "sint32":
+                return builder.setType(TYPE_SINT32);
+            case "sint64":
+                return builder.setType(TYPE_SINT64);
+            default: {
+                builder.setTypeName(type.toString());
+                if (messageNames.contains(type.toString())) {
+                    builder.setType(TYPE_MESSAGE);
+                }
+                else {
+                    builder.setType(TYPE_ENUM);
+                }
+                return builder;
+            }
+        }
+    }
+
+    public static Label getLabel(Field.Label label)
+    {
+        switch (label) {
+            case OPTIONAL:
+                return Label.LABEL_OPTIONAL;
+            case REPEATED:
+                return Label.LABEL_REPEATED;
+            case REQUIRED:
+                return Label.LABEL_REQUIRED;
+            default:
+                throw new IllegalArgumentException("Unknown label");
+        }
+    }
+
+    private static String getNameForMapField(String fieldName)
+    {
+        return LOWER_UNDERSCORE.to(UPPER_CAMEL, fieldName) + "Entry";
+    }
+}

diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufValueProvider.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufValueProvider.java
new file mode 100644
index 000000000000..a7f80a2d4d88
--- /dev/null
+++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufValueProvider.java
@@ -0,0 +1,322 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.decoder.protobuf; + +import com.google.protobuf.ByteString; +import com.google.protobuf.Descriptors.EnumValueDescriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.DynamicMessage; +import io.airlift.slice.Slice; +import io.airlift.slice.Slices; +import io.trino.decoder.FieldValueProvider; +import io.trino.spi.TrinoException; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.BigintType; +import io.trino.spi.type.BooleanType; +import io.trino.spi.type.DoubleType; +import io.trino.spi.type.IntegerType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RealType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.SmallintType; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.TinyintType; +import io.trino.spi.type.Type; +import io.trino.spi.type.VarbinaryType; +import io.trino.spi.type.VarcharType; + +import javax.annotation.Nullable; + +import java.util.Collection; +import java.util.List; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.decoder.DecoderErrorCode.DECODER_CONVERSION_NOT_SUPPORTED; +import static io.trino.spi.type.TimestampType.MAX_SHORT_PRECISION; +import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND; +import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND; +import static io.trino.spi.type.Timestamps.round; +import static io.trino.spi.type.Timestamps.roundDiv; +import static io.trino.spi.type.Varchars.truncateToLength; +import static java.lang.Float.floatToIntBits; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class ProtobufValueProvider + extends FieldValueProvider +{ + @Nullable + private final Object value; + private final Type columnType; + private final String columnName; + + public ProtobufValueProvider(@Nullable Object value, Type columnType, String columnName) + { + this.value = value; + this.columnType = requireNonNull(columnType, "columnType is null"); + this.columnName = requireNonNull(columnName, "columnName is null"); + } + + @Override + public boolean isNull() + { + return value == null; + } + + @Override + public double getDouble() + { + requireNonNull(value, "value is null"); + if (value instanceof Double || value instanceof Float) { + return ((Number) value).doubleValue(); + } + throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("cannot decode object of '%s' as '%s' for column '%s'", value.getClass(), columnType, columnName)); + } + + @Override + public boolean getBoolean() + { + requireNonNull(value, "value is null"); + if (value instanceof Boolean) { + return (Boolean) value; + } + throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("cannot decode object of '%s' as '%s' for column '%s'", value.getClass(), columnType, columnName)); + } + + @Override + public long getLong() + { + requireNonNull(value, "value is null"); + if (value instanceof Long || value instanceof Integer) { + return ((Number) value).longValue(); + } + if (value instanceof Float) { + return Float.floatToIntBits((Float) value); + } + if (value instanceof 
DynamicMessage) {
+            checkArgument(columnType instanceof TimestampType, "type should be an instance of Timestamp");
+            return parseTimestamp(((TimestampType) columnType).getPrecision(), (DynamicMessage) value);
+        }
+        throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("cannot decode object of '%s' as '%s' for column '%s'", value.getClass(), columnType, columnName));
+    }
+
+    @Override
+    public Slice getSlice()
+    {
+        return getSlice(value, columnType, columnName);
+    }
+
+    @Override
+    public Block getBlock()
+    {
+        return serializeObject(null, value, columnType, columnName);
+    }
+
+    private static Slice getSlice(Object value, Type type, String columnName)
+    {
+        requireNonNull(value, "value is null");
+        if ((type instanceof VarcharType && value instanceof CharSequence) || value instanceof EnumValueDescriptor) {
+            return truncateToLength(utf8Slice(value.toString()), type);
+        }
+
+        if (type instanceof VarbinaryType && value instanceof ByteString) {
+            return Slices.wrappedBuffer(((ByteString) value).toByteArray());
+        }
+
+        throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("cannot decode object of '%s' as '%s' for column '%s'", value.getClass(), type, columnName));
+    }
+
+    @Nullable
+    private static Block serializeObject(BlockBuilder builder, Object value, Type type, String columnName)
+    {
+        if (type instanceof ArrayType) {
+            return serializeList(builder, value, type, columnName);
+        }
+        if (type instanceof MapType) {
+            return serializeMap(builder, value, type, columnName);
+        }
+        if (type instanceof RowType) {
+            return serializeRow(builder, value, type, columnName);
+        }
+
+        serializePrimitive(builder, value, type, columnName);
+        return null;
+    }
+
+    @Nullable
+    private static Block serializeList(BlockBuilder parentBlockBuilder, @Nullable Object value, Type type, String columnName)
+    {
+        if (value == null) {
+            checkState(parentBlockBuilder != null, "parentBlockBuilder is null");
+            parentBlockBuilder.appendNull();
+            return null;
+        }
+        List<?> list = (List<?>) value;
+        List<Type> typeParameters = type.getTypeParameters();
+        Type elementType = typeParameters.get(0);
+
+        BlockBuilder blockBuilder = elementType.createBlockBuilder(null, list.size());
+        for (Object element : list) {
+            serializeObject(blockBuilder, element, elementType, columnName);
+        }
+        if (parentBlockBuilder != null) {
+            type.writeObject(parentBlockBuilder, blockBuilder.build());
+            return null;
+        }
+        return blockBuilder.build();
+    }
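+
+    // Short-precision timestamps arrive as google.protobuf.Timestamp messages;
+    // parseTimestamp near the bottom of this class converts (seconds, nanos) to
+    // epoch micros: micros = seconds * 1_000_000 + round(nanos / 1_000). For
+    // example, seconds = 1 and nanos = 1_500 yield 1_000_002 micros.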
+    private static void serializePrimitive(BlockBuilder blockBuilder, @Nullable Object value, Type type, String columnName)
+    {
+        requireNonNull(blockBuilder, "parent blockBuilder is null");
+
+        if (value == null) {
+            blockBuilder.appendNull();
+            return;
+        }
+
+        if (type instanceof BooleanType) {
+            type.writeBoolean(blockBuilder, (Boolean) value);
+            return;
+        }
+
+        if ((value instanceof Integer || value instanceof Long) && (type instanceof BigintType || type instanceof IntegerType || type instanceof SmallintType || type instanceof TinyintType)) {
+            type.writeLong(blockBuilder, ((Number) value).longValue());
+            return;
+        }
+
+        if (type instanceof DoubleType && value instanceof Double) {
+            type.writeDouble(blockBuilder, (Double) value);
+            return;
+        }
+
+        if (type instanceof RealType && value instanceof Float) {
+            type.writeLong(blockBuilder, floatToIntBits((Float) value));
+            return;
+        }
+
+        if (type instanceof VarcharType || type instanceof VarbinaryType) {
+            type.writeSlice(blockBuilder, getSlice(value, type, columnName));
+            return;
+        }
+
+        if (type instanceof TimestampType && ((TimestampType) type).isShort()) {
+            checkArgument(value instanceof DynamicMessage, "value should be an instance of DynamicMessage");
+            type.writeLong(blockBuilder, parseTimestamp(((TimestampType) type).getPrecision(), (DynamicMessage) value));
+            return;
+        }
+
+        throw new TrinoException(DECODER_CONVERSION_NOT_SUPPORTED, format("cannot decode object of '%s' as '%s' for column '%s'", value.getClass(), type, columnName));
+    }
+
+    @Nullable
+    private static Block serializeMap(BlockBuilder parentBlockBuilder, @Nullable Object value, Type type, String columnName)
+    {
+        if (value == null) {
+            checkState(parentBlockBuilder != null, "parentBlockBuilder is null");
+            parentBlockBuilder.appendNull();
+            return null;
+        }
+
+        Collection<DynamicMessage> dynamicMessages = ((Collection<?>) value).stream()
+                .map(DynamicMessage.class::cast)
+                .collect(toImmutableList());
+        List<Type> typeParameters = type.getTypeParameters();
+        Type keyType = typeParameters.get(0);
+        Type valueType = typeParameters.get(1);
+
+        BlockBuilder blockBuilder;
+        if (parentBlockBuilder != null) {
+            blockBuilder = parentBlockBuilder;
+        }
+        else {
+            blockBuilder = type.createBlockBuilder(null, 1);
+        }
+
+        BlockBuilder entryBuilder = blockBuilder.beginBlockEntry();
+        for (DynamicMessage dynamicMessage : dynamicMessages) {
+            if (dynamicMessage.getField(dynamicMessage.getDescriptorForType().findFieldByNumber(1)) != null) {
+                serializeObject(entryBuilder, dynamicMessage.getField(getFieldDescriptor(dynamicMessage, 1)), keyType, columnName);
+                serializeObject(entryBuilder, dynamicMessage.getField(getFieldDescriptor(dynamicMessage, 2)), valueType, columnName);
+            }
+        }
+        blockBuilder.closeEntry();
+
+        if (parentBlockBuilder == null) {
+            return blockBuilder.getObject(0, Block.class);
+        }
+        return null;
+    }
+
+    @Nullable
+    private static Block serializeRow(BlockBuilder parentBlockBuilder, @Nullable Object value, Type type, String columnName)
+    {
+        if (value == null) {
+            checkState(parentBlockBuilder != null, "parent block builder is null");
+            parentBlockBuilder.appendNull();
+            return null;
+        }
+
+        BlockBuilder blockBuilder;
+        if (parentBlockBuilder != null) {
+            blockBuilder = parentBlockBuilder;
+        }
+        else {
+            blockBuilder = type.createBlockBuilder(null, 1);
+        }
+        BlockBuilder singleRowBuilder = blockBuilder.beginBlockEntry();
+        DynamicMessage record = (DynamicMessage) value;
+        List<RowType.Field> fields = ((RowType) type).getFields();
+        for (RowType.Field field : fields) {
+            checkState(field.getName().isPresent(), "field name not found");
+            FieldDescriptor fieldDescriptor = getFieldDescriptor(record, field.getName().get());
+            checkState(fieldDescriptor != null, format("Unknown Field %s", field.getName().get()));
+            serializeObject(
+                    singleRowBuilder,
+                    record.getField(fieldDescriptor),
+                    field.getType(),
+                    columnName);
+        }
+        blockBuilder.closeEntry();
+        if (parentBlockBuilder == null) {
+            return blockBuilder.getObject(0, Block.class);
+        }
+        return null;
+    }
+
+    private static long parseTimestamp(int precision, DynamicMessage timestamp)
+    {
+        long seconds = (Long) timestamp.getField(timestamp.getDescriptorForType().findFieldByName("seconds"));
+        int nanos = (Integer) timestamp.getField(timestamp.getDescriptorForType().findFieldByName("nanos"));
+        long micros = seconds * MICROSECONDS_PER_SECOND;
+        micros += roundDiv(nanos, NANOSECONDS_PER_MICROSECOND);
+        checkArgument(precision <= MAX_SHORT_PRECISION, "precision must be less than max short timestamp precision (" + MAX_SHORT_PRECISION + ")");
+        return round(micros, MAX_SHORT_PRECISION - precision);
+    }
+
+    private static FieldDescriptor
getFieldDescriptor(DynamicMessage message, String name) + { + return message.getDescriptorForType().findFieldByName(name); + } + + private static FieldDescriptor getFieldDescriptor(DynamicMessage message, int index) + { + return message.getDescriptorForType().findFieldByNumber(index); + } +} diff --git a/lib/trino-record-decoder/src/test/java/io/trino/decoder/protobuf/ProtobufDataProviders.java b/lib/trino-record-decoder/src/test/java/io/trino/decoder/protobuf/ProtobufDataProviders.java new file mode 100644 index 000000000000..d9383833d32d --- /dev/null +++ b/lib/trino-record-decoder/src/test/java/io/trino/decoder/protobuf/ProtobufDataProviders.java @@ -0,0 +1,71 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.decoder.protobuf; + +import org.testng.annotations.DataProvider; + +import java.time.LocalDateTime; + +import static io.trino.testing.DateTimeTestingUtils.sqlTimestampOf; +import static java.lang.Math.PI; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.stream.Collectors.joining; +import static java.util.stream.IntStream.range; + +public class ProtobufDataProviders +{ + @DataProvider + public Object[][] allTypesDataProvider() + { + return new Object[][] { + { + "Trino", + 1, + 493857959588286460L, + PI, + 3.14f, + true, + "ONE", + sqlTimestampOf(3, LocalDateTime.parse("2020-12-12T15:35:45.923")), + "X'65683F'".getBytes(UTF_8) + }, + { + range(0, 5000) + .mapToObj(Integer::toString) + .collect(joining(", ")), + Integer.MAX_VALUE, + Long.MIN_VALUE, + Double.MAX_VALUE, + Float.MIN_VALUE, + false, + "ZERO", + sqlTimestampOf(3, LocalDateTime.parse("1856-01-12T05:25:14.456")), + new byte[0] + }, + { + range(5000, 10000) + .mapToObj(Integer::toString) + .collect(joining(", ")), + Integer.MIN_VALUE, + Long.MAX_VALUE, + Double.NaN, + Float.NEGATIVE_INFINITY, + false, + "ZERO", + sqlTimestampOf(3, LocalDateTime.parse("0001-01-01T00:00:00.923")), + "X'65683F'".getBytes(UTF_8) + } + }; + } +} diff --git a/lib/trino-record-decoder/src/test/java/io/trino/decoder/protobuf/TestProtobufDecoder.java b/lib/trino-record-decoder/src/test/java/io/trino/decoder/protobuf/TestProtobufDecoder.java new file mode 100644 index 000000000000..941c1f43365b --- /dev/null +++ b/lib/trino-record-decoder/src/test/java/io/trino/decoder/protobuf/TestProtobufDecoder.java @@ -0,0 +1,280 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.decoder.protobuf; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Timestamp; +import io.airlift.slice.Slices; +import io.trino.decoder.DecoderColumnHandle; +import io.trino.decoder.DecoderTestColumnHandle; +import io.trino.decoder.FieldValueProvider; +import io.trino.decoder.RowDecoder; +import io.trino.spi.block.Block; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.SqlTimestamp; +import io.trino.spi.type.SqlVarbinary; +import io.trino.testing.TestingSession; +import org.testng.annotations.Test; + +import java.util.Map; +import java.util.Set; + +import static io.trino.decoder.protobuf.ProtobufRowDecoderFactory.DEFAULT_MESSAGE; +import static io.trino.decoder.util.DecoderTestUtil.checkValue; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.TimestampType.TIMESTAMP_MICROS; +import static io.trino.spi.type.TimestampType.createTimestampType; +import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND; +import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND; +import static io.trino.spi.type.TypeSignature.mapType; +import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; +import static java.lang.Math.floorDiv; +import static java.lang.Math.floorMod; +import static org.assertj.core.api.Assertions.assertThatThrownBy; +import static org.testng.Assert.assertEquals; + +public class TestProtobufDecoder +{ + private static final ProtobufRowDecoderFactory DECODER_FACTORY = new ProtobufRowDecoderFactory(new FixedSchemaDynamicMessageProvider.Factory()); + + @Test(dataProvider = "allTypesDataProvider", dataProviderClass = ProtobufDataProviders.class) + public void testAllDataTypes(String stringData, Integer integerData, Long longData, Double doubleData, Float floatData, Boolean booleanData, String enumData, SqlTimestamp sqlTimestamp, byte[] bytesData) + throws Exception + { + DecoderTestColumnHandle stringColumn = new DecoderTestColumnHandle(0, "stringColumn", createVarcharType(30000), "stringColumn", null, null, false, false, false); + DecoderTestColumnHandle integerColumn = new DecoderTestColumnHandle(1, "integerColumn", INTEGER, "integerColumn", null, null, false, false, false); + DecoderTestColumnHandle longColumn = new DecoderTestColumnHandle(2, "longColumn", BIGINT, "longColumn", null, null, false, false, false); + DecoderTestColumnHandle doubleColumn = new DecoderTestColumnHandle(3, "doubleColumn", DOUBLE, "doubleColumn", null, null, false, false, false); + DecoderTestColumnHandle floatColumn = new DecoderTestColumnHandle(4, "floatColumn", REAL, "floatColumn", null, null, false, false, false); + DecoderTestColumnHandle booleanColumn = new DecoderTestColumnHandle(5, "booleanColumn", BOOLEAN, "booleanColumn", null, null, false, false, false); + DecoderTestColumnHandle numberColumn = new DecoderTestColumnHandle(6, 
"numberColumn", createVarcharType(4), "numberColumn", null, null, false, false, false); + DecoderTestColumnHandle timestampColumn = new DecoderTestColumnHandle(7, "timestampColumn", createTimestampType(3), "timestampColumn", null, null, false, false, false); + DecoderTestColumnHandle bytesColumn = new DecoderTestColumnHandle(8, "bytesColumn", VARBINARY, "bytesColumn", null, null, false, false, false); + + Descriptor descriptor = getDescriptor("all_datatypes.proto"); + DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(descriptor); + messageBuilder.setField(descriptor.findFieldByName("stringColumn"), stringData); + messageBuilder.setField(descriptor.findFieldByName("integerColumn"), integerData); + messageBuilder.setField(descriptor.findFieldByName("longColumn"), longData); + messageBuilder.setField(descriptor.findFieldByName("doubleColumn"), doubleData); + messageBuilder.setField(descriptor.findFieldByName("floatColumn"), floatData); + messageBuilder.setField(descriptor.findFieldByName("booleanColumn"), booleanData); + messageBuilder.setField(descriptor.findFieldByName("numberColumn"), descriptor.findEnumTypeByName("Number").findValueByName(enumData)); + messageBuilder.setField(descriptor.findFieldByName("timestampColumn"), getTimestamp(sqlTimestamp)); + messageBuilder.setField(descriptor.findFieldByName("bytesColumn"), bytesData); + + Map decodedRow = createRowDecoder("all_datatypes.proto", ImmutableSet.of(stringColumn, integerColumn, longColumn, doubleColumn, floatColumn, booleanColumn, numberColumn, timestampColumn, bytesColumn)) + .decodeRow(messageBuilder.build().toByteArray()) + .orElseThrow(AssertionError::new); + + assertEquals(decodedRow.size(), 9); + + checkValue(decodedRow, stringColumn, stringData); + checkValue(decodedRow, integerColumn, integerData); + checkValue(decodedRow, longColumn, longData); + checkValue(decodedRow, doubleColumn, doubleData); + checkValue(decodedRow, floatColumn, floatData); + checkValue(decodedRow, booleanColumn, booleanData); + checkValue(decodedRow, numberColumn, enumData); + checkValue(decodedRow, timestampColumn, sqlTimestamp.getEpochMicros()); + checkValue(decodedRow, bytesColumn, Slices.wrappedBuffer(bytesData)); + } + + @Test(dataProvider = "allTypesDataProvider", dataProviderClass = ProtobufDataProviders.class) + public void testStructuralDataTypes(String stringData, Integer integerData, Long longData, Double doubleData, Float floatData, Boolean booleanData, String enumData, SqlTimestamp sqlTimestamp, byte[] bytesData) + throws Exception + { + DecoderTestColumnHandle listColumn = new DecoderTestColumnHandle(0, "list", new ArrayType(createVarcharType(100)), "list", null, null, false, false, false); + DecoderTestColumnHandle mapColumn = new DecoderTestColumnHandle(1, "map", TESTING_TYPE_MANAGER.getType(mapType(VARCHAR.getTypeSignature(), VARCHAR.getTypeSignature())), "map", null, null, false, false, false); + DecoderTestColumnHandle rowColumn = new DecoderTestColumnHandle( + 2, + "row", + RowType.from(ImmutableList.builder() + .add(RowType.field("string_column", createVarcharType(30000))) + .add(RowType.field("integer_column", INTEGER)) + .add(RowType.field("long_column", BIGINT)) + .add(RowType.field("double_column", DOUBLE)) + .add(RowType.field("float_column", REAL)) + .add(RowType.field("boolean_column", BOOLEAN)) + .add(RowType.field("number_column", createVarcharType(4))) + .add(RowType.field("timestamp_column", createTimestampType(6))) + .add(RowType.field("bytes_column", VARBINARY)) + .build()), + "row", + null, + null, + 
+
+    @Test(dataProvider = "allTypesDataProvider", dataProviderClass = ProtobufDataProviders.class)
+    public void testStructuralDataTypes(String stringData, Integer integerData, Long longData, Double doubleData, Float floatData, Boolean booleanData, String enumData, SqlTimestamp sqlTimestamp, byte[] bytesData)
+            throws Exception
+    {
+        DecoderTestColumnHandle listColumn = new DecoderTestColumnHandle(0, "list", new ArrayType(createVarcharType(100)), "list", null, null, false, false, false);
+        DecoderTestColumnHandle mapColumn = new DecoderTestColumnHandle(1, "map", TESTING_TYPE_MANAGER.getType(mapType(VARCHAR.getTypeSignature(), VARCHAR.getTypeSignature())), "map", null, null, false, false, false);
+        DecoderTestColumnHandle rowColumn = new DecoderTestColumnHandle(
+                2,
+                "row",
+                RowType.from(ImmutableList.<RowType.Field>builder()
+                        .add(RowType.field("string_column", createVarcharType(30000)))
+                        .add(RowType.field("integer_column", INTEGER))
+                        .add(RowType.field("long_column", BIGINT))
+                        .add(RowType.field("double_column", DOUBLE))
+                        .add(RowType.field("float_column", REAL))
+                        .add(RowType.field("boolean_column", BOOLEAN))
+                        .add(RowType.field("number_column", createVarcharType(4)))
+                        .add(RowType.field("timestamp_column", createTimestampType(6)))
+                        .add(RowType.field("bytes_column", VARBINARY))
+                        .build()),
+                "row",
+                null,
+                null,
+                false,
+                false,
+                false);
+
+        Descriptor descriptor = getDescriptor("structural_datatypes.proto");
+        DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(descriptor);
+        messageBuilder.setField(descriptor.findFieldByName("list"), ImmutableList.of("Presto"));
+
+        Descriptor mapDescriptor = descriptor.findFieldByName("map").getMessageType();
+        DynamicMessage.Builder mapBuilder = DynamicMessage.newBuilder(mapDescriptor);
+        mapBuilder.setField(mapDescriptor.findFieldByName("key"), "Key");
+        mapBuilder.setField(mapDescriptor.findFieldByName("value"), "Value");
+        messageBuilder.setField(descriptor.findFieldByName("map"), ImmutableList.of(mapBuilder.build()));
+
+        Descriptor rowDescriptor = descriptor.findFieldByName("row").getMessageType();
+        DynamicMessage.Builder rowBuilder = DynamicMessage.newBuilder(rowDescriptor);
+        rowBuilder.setField(rowDescriptor.findFieldByName("string_column"), stringData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("integer_column"), integerData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("long_column"), longData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("double_column"), doubleData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("float_column"), floatData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("boolean_column"), booleanData);
+        rowBuilder.setField(rowDescriptor.findFieldByName("number_column"), descriptor.findEnumTypeByName("Number").findValueByName(enumData));
+        rowBuilder.setField(rowDescriptor.findFieldByName("timestamp_column"), getTimestamp(sqlTimestamp));
+        rowBuilder.setField(rowDescriptor.findFieldByName("bytes_column"), bytesData);
+        messageBuilder.setField(descriptor.findFieldByName("row"), rowBuilder.build());
+
+        Map<DecoderColumnHandle, FieldValueProvider> decodedRow = createRowDecoder("structural_datatypes.proto", ImmutableSet.of(listColumn, mapColumn, rowColumn))
+                .decodeRow(messageBuilder.build().toByteArray())
+                .orElseThrow(AssertionError::new);
+
+        assertEquals(decodedRow.size(), 3);
+
+        Block listBlock = decodedRow.get(listColumn).getBlock();
+        assertEquals(VARCHAR.getSlice(listBlock, 0).toStringUtf8(), "Presto");
+
+        Block mapBlock = decodedRow.get(mapColumn).getBlock();
+        assertEquals(VARCHAR.getSlice(mapBlock, 0).toStringUtf8(), "Key");
+        assertEquals(VARCHAR.getSlice(mapBlock, 1).toStringUtf8(), "Value");
+
+        Block rowBlock = decodedRow.get(rowColumn).getBlock();
+        ConnectorSession session = TestingSession.testSessionBuilder().build().toConnectorSession();
+        assertEquals(VARCHAR.getObjectValue(session, rowBlock, 0), stringData);
+        assertEquals(INTEGER.getObjectValue(session, rowBlock, 1), integerData);
+        assertEquals(BIGINT.getObjectValue(session, rowBlock, 2), longData);
+        assertEquals(DOUBLE.getObjectValue(session, rowBlock, 3), doubleData);
+        assertEquals(REAL.getObjectValue(session, rowBlock, 4), floatData);
+        assertEquals(BOOLEAN.getObjectValue(session, rowBlock, 5), booleanData);
+        assertEquals(VARCHAR.getObjectValue(session, rowBlock, 6), enumData);
+        assertEquals(TIMESTAMP_MICROS.getObjectValue(session, rowBlock, 7), sqlTimestamp.roundTo(6));
+        assertEquals(VARBINARY.getObjectValue(session, rowBlock, 8), new SqlVarbinary(bytesData));
+    }
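+
+    // Decoding is lazy: decodeRow succeeds even when a row column references a
+    // field that does not exist in the schema; the failure only surfaces when
+    // the value is materialized through getBlock(), as the next test demonstrates.
+
+    @Test
+    public void testMissingFieldInRowType()
+            throws Exception
+    {
+        DecoderTestColumnHandle rowColumn = new DecoderTestColumnHandle(
+                2,
+                "row",
+                RowType.from(ImmutableList.of(RowType.field("unknown_mapping", createVarcharType(30000)))),
+                "row",
+                null,
+                null,
+                false,
+                false,
+                false);
+
+        Descriptor descriptor = getDescriptor("structural_datatypes.proto");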
DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(descriptor); + + Descriptor rowDescriptor = descriptor.findFieldByName("row").getMessageType(); + DynamicMessage.Builder rowBuilder = DynamicMessage.newBuilder(rowDescriptor); + rowBuilder.setField(rowDescriptor.findFieldByName("string_column"), "Test"); + messageBuilder.setField(descriptor.findFieldByName("row"), rowBuilder.build()); + + Map decodedRow = createRowDecoder("structural_datatypes.proto", ImmutableSet.of(rowColumn)) + .decodeRow(messageBuilder.build().toByteArray()) + .orElseThrow(AssertionError::new); + + assertThatThrownBy(() -> decodedRow.get(rowColumn).getBlock()) + .hasMessageMatching("Unknown Field unknown_mapping"); + } + + @Test(dataProvider = "allTypesDataProvider", dataProviderClass = ProtobufDataProviders.class) + public void testRowFlattening(String stringData, Integer integerData, Long longData, Double doubleData, Float floatData, Boolean booleanData, String enumData, SqlTimestamp sqlTimestamp, byte[] bytesData) + throws Exception + { + DecoderTestColumnHandle stringColumn = new DecoderTestColumnHandle(0, "stringColumn", createVarcharType(30000), "row/string_column", null, null, false, false, false); + DecoderTestColumnHandle integerColumn = new DecoderTestColumnHandle(1, "integerColumn", INTEGER, "row/integer_column", null, null, false, false, false); + DecoderTestColumnHandle longColumn = new DecoderTestColumnHandle(2, "longColumn", BIGINT, "row/long_column", null, null, false, false, false); + DecoderTestColumnHandle doubleColumn = new DecoderTestColumnHandle(3, "doubleColumn", DOUBLE, "row/double_column", null, null, false, false, false); + DecoderTestColumnHandle floatColumn = new DecoderTestColumnHandle(4, "floatColumn", REAL, "row/float_column", null, null, false, false, false); + DecoderTestColumnHandle booleanColumn = new DecoderTestColumnHandle(5, "booleanColumn", BOOLEAN, "row/boolean_column", null, null, false, false, false); + DecoderTestColumnHandle numberColumn = new DecoderTestColumnHandle(6, "numberColumn", createVarcharType(4), "row/number_column", null, null, false, false, false); + DecoderTestColumnHandle timestampColumn = new DecoderTestColumnHandle(6, "timestampColumn", createTimestampType(3), "row/timestamp_column", null, null, false, false, false); + DecoderTestColumnHandle bytesColumn = new DecoderTestColumnHandle(5, "bytesColumn", VARBINARY, "row/bytes_column", null, null, false, false, false); + + Descriptor descriptor = getDescriptor("structural_datatypes.proto"); + DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(descriptor); + + Descriptor rowDescriptor = descriptor.findNestedTypeByName("Row"); + DynamicMessage.Builder rowBuilder = DynamicMessage.newBuilder(rowDescriptor); + rowBuilder.setField(rowDescriptor.findFieldByName("string_column"), stringData); + rowBuilder.setField(rowDescriptor.findFieldByName("integer_column"), integerData); + rowBuilder.setField(rowDescriptor.findFieldByName("long_column"), longData); + rowBuilder.setField(rowDescriptor.findFieldByName("double_column"), doubleData); + rowBuilder.setField(rowDescriptor.findFieldByName("float_column"), floatData); + rowBuilder.setField(rowDescriptor.findFieldByName("boolean_column"), booleanData); + rowBuilder.setField(rowDescriptor.findFieldByName("number_column"), descriptor.findEnumTypeByName("Number").findValueByName(enumData)); + rowBuilder.setField(rowDescriptor.findFieldByName("timestamp_column"), getTimestamp(sqlTimestamp)); + 
rowBuilder.setField(rowDescriptor.findFieldByName("bytes_column"), bytesData); + messageBuilder.setField(descriptor.findFieldByName("row"), rowBuilder.build()); + + Map decodedRow = createRowDecoder("structural_datatypes.proto", ImmutableSet.of(stringColumn, integerColumn, longColumn, doubleColumn, floatColumn, booleanColumn, numberColumn, timestampColumn, bytesColumn)) + .decodeRow(messageBuilder.build().toByteArray()) + .orElseThrow(AssertionError::new); + + assertEquals(decodedRow.size(), 9); + + checkValue(decodedRow, stringColumn, stringData); + checkValue(decodedRow, integerColumn, integerData); + checkValue(decodedRow, longColumn, longData); + checkValue(decodedRow, doubleColumn, doubleData); + checkValue(decodedRow, floatColumn, floatData); + checkValue(decodedRow, booleanColumn, booleanData); + checkValue(decodedRow, numberColumn, enumData); + checkValue(decodedRow, timestampColumn, sqlTimestamp.getEpochMicros()); + checkValue(decodedRow, bytesColumn, Slices.wrappedBuffer(bytesData)); + } + + private Timestamp getTimestamp(SqlTimestamp sqlTimestamp) + { + return Timestamp.newBuilder() + .setSeconds(floorDiv(sqlTimestamp.getEpochMicros(), MICROSECONDS_PER_SECOND)) + .setNanos(floorMod(sqlTimestamp.getEpochMicros(), MICROSECONDS_PER_SECOND) * NANOSECONDS_PER_MICROSECOND) + .build(); + } + + private RowDecoder createRowDecoder(String fileName, Set columns) + throws Exception + { + return DECODER_FACTORY.create(ImmutableMap.of("dataSchema", ProtobufUtils.getProtoFile("decoder/protobuf/" + fileName)), columns); + } + + private Descriptor getDescriptor(String fileName) + throws Exception + { + return ProtobufUtils.getFileDescriptor(ProtobufUtils.getProtoFile("decoder/protobuf/" + fileName)).findMessageTypeByName(DEFAULT_MESSAGE); + } +} diff --git a/lib/trino-record-decoder/src/test/resources/decoder/protobuf/all_datatypes.proto b/lib/trino-record-decoder/src/test/resources/decoder/protobuf/all_datatypes.proto new file mode 100644 index 000000000000..06b9c5421c89 --- /dev/null +++ b/lib/trino-record-decoder/src/test/resources/decoder/protobuf/all_datatypes.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +import "google/protobuf/timestamp.proto"; + +message schema { + string stringColumn = 1 ; + uint32 integerColumn = 2; + uint64 longColumn = 3; + double doubleColumn = 4; + float floatColumn = 5; + bool booleanColumn = 6; + enum Number { + ZERO = 0; + ONE = 1; + }; + Number numberColumn = 7; + google.protobuf.Timestamp timestampColumn = 8; + bytes bytesColumn = 9; +} diff --git a/lib/trino-record-decoder/src/test/resources/decoder/protobuf/structural_datatypes.proto b/lib/trino-record-decoder/src/test/resources/decoder/protobuf/structural_datatypes.proto new file mode 100644 index 000000000000..8e3083e2763a --- /dev/null +++ b/lib/trino-record-decoder/src/test/resources/decoder/protobuf/structural_datatypes.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +import "google/protobuf/timestamp.proto"; + +message schema { + repeated string list = 1; + map map = 2; + enum Number { + ZERO = 0; + ONE = 1; + TWO = 2; + }; + message Row { + string string_column = 1; + uint32 integer_column = 2; + uint64 long_column = 3; + double double_column = 4; + float float_column = 5; + bool boolean_column = 6; + Number number_column = 7; + google.protobuf.Timestamp timestamp_column = 8; + bytes bytes_column = 9; + }; + Row row = 3; + message NestedRow { + repeated Row nested_list = 1; + map nested_map = 2; + Row row = 3; + }; + NestedRow nested_row = 4; +} diff --git a/plugin/trino-kafka/pom.xml 
b/plugin/trino-kafka/pom.xml index 35685c73bb37..30095000975d 100644 --- a/plugin/trino-kafka/pom.xml +++ b/plugin/trino-kafka/pom.xml @@ -272,4 +272,19 @@ test + + + + + org.basepom.maven + duplicate-finder-maven-plugin + + + + google/protobuf/.*\.proto$ + + + + + diff --git a/plugin/trino-kinesis/pom.xml b/plugin/trino-kinesis/pom.xml index 398f7b567d21..2c4febf98b77 100644 --- a/plugin/trino-kinesis/pom.xml +++ b/plugin/trino-kinesis/pom.xml @@ -186,6 +186,16 @@ + + org.basepom.maven + duplicate-finder-maven-plugin + + + + google/protobuf/.*\.proto$ + + + diff --git a/plugin/trino-redis/pom.xml b/plugin/trino-redis/pom.xml index e2720b124d17..deee4c884882 100644 --- a/plugin/trino-redis/pom.xml +++ b/plugin/trino-redis/pom.xml @@ -201,4 +201,19 @@ test + + + + + org.basepom.maven + duplicate-finder-maven-plugin + + + + google/protobuf/.*\.proto$ + + + + + diff --git a/pom.xml b/pom.xml index 2253da1e08ee..44918bfe3e6c 100644 --- a/pom.xml +++ b/pom.xml @@ -68,6 +68,9 @@ 7.1.4 1.0.0 4.7.2 + 3.21.6 + 3.2.2 + 1.4.0 72 @@ -1182,6 +1185,18 @@ ${dep.errorprone.version} + + com.google.protobuf + protobuf-java + ${dep.protobuf.version} + + + + com.google.protobuf + protobuf-java-util + ${dep.protobuf.version} + + com.h2database h2 @@ -1301,6 +1316,13 @@ ${dep.okhttp.version} + + com.squareup.wire + wire-schema + ${dep.wire.version} + + + com.teradata re2j-td @@ -1725,6 +1747,18 @@ 19.0.0 + + org.jetbrains.kotlin + kotlin-stdlib + ${dep.kotlin.version} + + + + org.jetbrains.kotlin + kotlin-stdlib-common + ${dep.kotlin.version} + + org.jgrapht jgrapht-core From 5c40f44e97afc493ea137d8f1c38adf95c474ac9 Mon Sep 17 00:00:00 2001 From: Neville Li Date: Tue, 8 Nov 2022 09:31:55 -0500 Subject: [PATCH 2/2] Add Protobuf support for trino-kafka --- .github/workflows/ci.yml | 2 +- docs/src/main/sphinx/connector/kafka.rst | 186 ++++++++- plugin/trino-kafka/pom.xml | 56 ++- .../plugin/kafka/encoder/EncoderModule.java | 2 + .../protobuf/ProtobufEncoderModule.java | 33 ++ .../encoder/protobuf/ProtobufRowEncoder.java | 356 +++++++++++++++++ .../protobuf/ProtobufRowEncoderFactory.java | 53 +++ .../protobuf/ProtobufSchemaParser.java | 115 ++++++ .../schema/confluent/ConfluentModule.java | 59 +++ ...tSchemaRegistryDynamicMessageProvider.java | 105 +++++ ...ithSchemaRegistryMinimalFunctionality.java | 324 +++++++++++++++ .../kafka/protobuf/TestProtobufEncoder.java | 368 ++++++++++++++++++ .../resources/protobuf/default_values.proto | 15 + .../resources/protobuf/evolved_schema.proto | 7 + .../resources/protobuf/initial_schema.proto | 6 + .../test/resources/protobuf/key_schema.proto | 5 + .../test/resources/protobuf/timestamps.proto | 13 + pom.xml | 25 +- testing/trino-product-tests-launcher/pom.xml | 25 ++ .../product/launcher/env/common/Kafka.java | 8 +- .../kafka.properties | 6 +- .../multinode-kafka-ssl/kafka.properties | 6 +- .../multinode-kafka/kafka.properties | 6 +- .../etc/catalog/kafka/all_datatypes.proto | 19 + .../catalog/kafka/all_datatypes_protobuf.json | 55 +++ .../etc/catalog/kafka/basic_datatypes.proto | 10 + .../kafka/basic_structural_datatypes.proto | 6 + .../kafka/read_basic_datatypes_protobuf.json | 41 ++ ...d_basic_structural_datatypes_protobuf.json | 21 + .../catalog/kafka/structural_datatype.proto | 28 ++ .../kafka/structural_datatype_protobuf.json | 55 +++ .../install-kafka-protobuf-provider.sh | 3 + testing/trino-product-tests/pom.xml | 18 + .../product/kafka/TestKafkaProtobuf.java | 128 ++++++ .../product/kafka/TestKafkaProtobufReads.java | 317 +++++++++++++++ 35 files changed, 2472 
insertions(+), 10 deletions(-) create mode 100644 plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufEncoderModule.java create mode 100644 plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufRowEncoder.java create mode 100644 plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufRowEncoderFactory.java create mode 100644 plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufSchemaParser.java create mode 100644 plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/ConfluentSchemaRegistryDynamicMessageProvider.java create mode 100644 plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestKafkaProtobufWithSchemaRegistryMinimalFunctionality.java create mode 100644 plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestProtobufEncoder.java create mode 100644 plugin/trino-kafka/src/test/resources/protobuf/default_values.proto create mode 100644 plugin/trino-kafka/src/test/resources/protobuf/evolved_schema.proto create mode 100644 plugin/trino-kafka/src/test/resources/protobuf/initial_schema.proto create mode 100644 plugin/trino-kafka/src/test/resources/protobuf/key_schema.proto create mode 100644 plugin/trino-kafka/src/test/resources/protobuf/timestamps.proto create mode 100644 testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/all_datatypes.proto create mode 100644 testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/all_datatypes_protobuf.json create mode 100644 testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/basic_datatypes.proto create mode 100644 testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/basic_structural_datatypes.proto create mode 100644 testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/read_basic_datatypes_protobuf.json create mode 100644 testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/read_basic_structural_datatypes_protobuf.json create mode 100644 testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/structural_datatype.proto create mode 100644 testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/structural_datatype_protobuf.json create mode 100644 testing/trino-product-tests-launcher/src/main/resources/install-kafka-protobuf-provider.sh create mode 100644 testing/trino-product-tests/src/main/java/io/trino/tests/product/kafka/TestKafkaProtobuf.java create mode 100644 testing/trino-product-tests/src/main/java/io/trino/tests/product/kafka/TestKafkaProtobufReads.java diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index df45678bd870..5f55125a136c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -555,7 +555,7 @@ jobs: path: | core/trino-server/target/*.tar.gz impacted-features.log - testing/trino-product-tests-launcher/target/*-executable.jar + testing/trino-product-tests-launcher/target/*.jar testing/trino-product-tests/target/*-executable.jar client/trino-cli/target/*-executable.jar retention-days: 1 diff --git a/docs/src/main/sphinx/connector/kafka.rst b/docs/src/main/sphinx/connector/kafka.rst 
index 58ca0f04c4a5..9325d237cd2b 100644
--- a/docs/src/main/sphinx/connector/kafka.rst
+++ b/docs/src/main/sphinx/connector/kafka.rst
@@ -487,12 +487,13 @@ Kafka inserts
 The Kafka connector supports the use of :doc:`/sql/insert` statements to write
 data to a Kafka topic. Table column data is mapped to Kafka messages as defined
 in the `table definition file <#table-definition-files>`__. There are
-four supported data formats for key and message encoding:
+five supported data formats for key and message encoding:
 
 * `raw format <#raw-encoder>`__
 * `CSV format <#csv-encoder>`__
 * `JSON format <#json-encoder>`__
 * `Avro format <#avro-encoder>`__
+* `Protobuf format <#protobuf-encoder>`__
 
 These data formats each have an encoder that maps column values into bytes to
 be sent to a Kafka topic.
@@ -537,6 +538,8 @@ The Kafka connector contains the following encoders:
   fields.
 * `Avro encoder <#avro-encoder>`__ - Table columns are mapped to Avro
   fields based on an Avro schema.
+* `Protobuf encoder <#protobuf-encoder>`__ - Table columns are mapped to
+  Protobuf fields based on a Protobuf schema.
 
 .. note::
 
@@ -978,11 +981,118 @@ definition is shown:
       "doc:" : "A basic avro schema"
     }
 
-The following is an example insert query for the preceding table definition::
+The following is an example insert query for the preceding table definition:
 
     INSERT INTO example_avro_table (field1, field2, field3)
     VALUES (123456789, 'example text', FALSE);
 
+Protobuf encoder
+""""""""""""""""
+
+The Protobuf encoder serializes rows to Protobuf DynamicMessages as defined by
+the `Protobuf schema <https://developers.google.com/protocol-buffers/docs/overview>`_.
+
+.. note::
+
+   The Protobuf schema is encoded with the table column values in each Kafka message.
+
+The ``dataSchema`` must be defined in the table definition file to use the
+Protobuf encoder. It points to the location of the ``proto`` file for the key
+or message.
+
+Protobuf schema files can be retrieved via HTTP or HTTPS from a remote server
+with the syntax:
+
+``"dataSchema": "http://example.org/schema/schema.proto"``
+
+Local files need to be available on all Trino nodes and use an absolute path in
+the syntax, for example:
+
+``"dataSchema": "/usr/local/schema/schema.proto"``
+
+The following field attributes are supported:
+
+* ``name`` - Name of the column in the Trino table.
+* ``type`` - Trino type of column.
+* ``mapping`` - slash-separated list of field names to select a field from the
+  Protobuf schema. If the field specified in ``mapping`` does not exist in the
+  original Protobuf schema, then a write operation fails.
+
+The following table lists supported Trino data types, which can be used in ``type``
+for the equivalent Protobuf field type.
+
+===================================== =======================================
+Trino data type                       Protobuf data type
+===================================== =======================================
+``BOOLEAN``                           ``bool``
+``INTEGER``                           ``int32``, ``uint32``, ``sint32``, ``fixed32``, ``sfixed32``
+``BIGINT``                            ``int64``, ``uint64``, ``sint64``, ``fixed64``, ``sfixed64``
+``DOUBLE``                            ``double``
+``REAL``                              ``float``
+``VARCHAR`` / ``VARCHAR(x)``          ``string``
+``VARBINARY``                         ``bytes``
+``ROW``                               ``Message``
+``ARRAY``                             Protobuf type with ``repeated`` field
+``MAP``                               ``Map``
+``TIMESTAMP``                         ``Timestamp``, predefined in ``timestamp.proto``
+===================================== =======================================
+
+The following example shows a Protobuf field definition in a `table definition
+file <#table-definition-files>`__ for a Kafka message:
+
+.. code-block:: json
+
+    {
+        "tableName": "your-table-name",
+        "schemaName": "your-schema-name",
+        "topicName": "your-topic-name",
+        "key": { "..." },
+        "message":
+        {
+            "dataFormat": "protobuf",
+            "dataSchema": "/message_schema.proto",
+            "fields":
+            [
+                {
+                    "name": "field1",
+                    "type": "BIGINT",
+                    "mapping": "field1"
+                },
+                {
+                    "name": "field2",
+                    "type": "VARCHAR",
+                    "mapping": "field2"
+                },
+                {
+                    "name": "field3",
+                    "type": "BOOLEAN",
+                    "mapping": "field3"
+                }
+            ]
+        }
+    }
+
+In the following example, a Protobuf schema definition for the preceding table
+definition is shown:
+
+.. code-block:: text
+
+    syntax = "proto3";
+
+    message schema {
+        uint64 field1 = 1;
+        string field2 = 2;
+        bool field3 = 3;
+    }
+
+The following is an example insert query for the preceding table definition:
+
+.. code-block:: sql
+
+    INSERT INTO example_protobuf_table (field1, field2, field3)
+    VALUES (123456789, 'example text', FALSE);
+
 .. _kafka-row-decoding:
 
 Row decoding
@@ -996,6 +1106,7 @@ The Kafka connector contains the following decoders:
 * ``csv`` - Kafka message is interpreted as comma separated message, and fields are mapped to table columns.
 * ``json`` - Kafka message is parsed as JSON, and JSON fields are mapped to table columns.
 * ``avro`` - Kafka message is parsed based on an Avro schema, and Avro fields are mapped to table columns.
+* ``protobuf`` - Kafka message is parsed based on a Protobuf schema, and Protobuf fields are mapped to table columns.
 
 .. note::
 
@@ -1237,6 +1348,76 @@ The schema evolution behavior is as follows:
    If the type coercion is supported by Avro, then the conversion happens. An
    error is thrown for incompatible types.
 
+Protobuf decoder
+""""""""""""""""
+
+The Protobuf decoder converts the bytes representing a message or key in
+Protobuf format based on a schema.
+
+For key/message, using the ``protobuf`` decoder, the ``dataSchema`` must be
+defined. It points to the location of a valid ``proto`` file of the message
+which needs to be decoded. This location can be a remote web server,
+``dataSchema: 'http://example.org/schema/schema.proto'``, or local file,
+``dataSchema: '/usr/local/schema/schema.proto'``. The decoder fails if the
+location is not accessible from the coordinator.
+
+For fields, the following attributes are supported:
+
+* ``name`` - Name of the column in the Trino table.
+* ``type`` - Trino data type of column.
+* ``mapping`` - slash-separated list of field names to select a field from the
+  Protobuf schema. If the field specified in ``mapping`` does not exist in the
+  original ``proto`` file, then a read operation returns NULL.
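+
+As an illustration only (the table, schema, and file names here are
+placeholders, not defaults of the connector), a ``message`` section using the
+``protobuf`` decoder follows the same shape as the encoder example above:
+
+.. code-block:: json
+
+    {
+        "dataFormat": "protobuf",
+        "dataSchema": "/usr/local/schema/schema.proto",
+        "fields":
+        [
+            {
+                "name": "field1",
+                "type": "BIGINT",
+                "mapping": "field1"
+            }
+        ]
+    }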
+
+The following table lists the supported Trino types, which can be used in
+``type`` for the equivalent Protobuf field types:
+
+===================================== =======================================
+Trino data type                       Allowed Protobuf data type
+===================================== =======================================
+``BOOLEAN``                           ``bool``
+``INTEGER``                           ``int32``, ``uint32``, ``sint32``, ``fixed32``, ``sfixed32``
+``BIGINT``                            ``int64``, ``uint64``, ``sint64``, ``fixed64``, ``sfixed64``
+``DOUBLE``                            ``double``
+``REAL``                              ``float``
+``VARCHAR`` / ``VARCHAR(x)``          ``string``
+``VARBINARY``                         ``bytes``
+``ROW``                               ``Message``
+``ARRAY``                             Protobuf type with ``repeated`` field
+``MAP``                               ``Map``
+``TIMESTAMP``                         ``Timestamp``, predefined in ``timestamp.proto``
+===================================== =======================================
+
+Protobuf schema evolution
++++++++++++++++++++++++++
+
+The Protobuf decoder supports the schema evolution feature with backward
+compatibility. With backward compatibility, a newer schema can be used to read
+Protobuf data created with an older schema. Any change in the Protobuf schema
+*must* also be reflected in the topic definition file.
+
+The schema evolution behavior is as follows:
+
+* Column added in new schema:
+  Data created with an older schema produces a *default* value when the table is using the new schema.
+
+* Column removed in new schema:
+  Data created with an older schema no longer outputs the data from the column that was removed.
+
+* Column is renamed in the new schema:
+  This is equivalent to removing the column and adding a new one, and data created with an older schema
+  produces a *default* value when the table is using the new schema.
+
+* Changing type of column in the new schema:
+  If the type coercion is supported by Protobuf, then the conversion happens. An error is thrown for incompatible types.
+
+Protobuf limitations
+++++++++++++++++++++
+
+* Protobuf-specific types like ``any`` and ``oneof`` are not supported.
+* The Protobuf ``Timestamp`` type has nanosecond precision, but Trino supports
+  decoding/encoding at microsecond precision.
+
 .. _kafka-sql-support:
 
 SQL support
@@ -1252,4 +1433,3 @@ supports the following features:
 
 * :doc:`/sql/insert`, encoded to a specified data format. See also
   :ref:`kafka-sql-inserts`.
- diff --git a/plugin/trino-kafka/pom.xml b/plugin/trino-kafka/pom.xml index 30095000975d..62d3900dc92f 100644 --- a/plugin/trino-kafka/pom.xml +++ b/plugin/trino-kafka/pom.xml @@ -88,6 +88,16 @@ guice + + com.google.protobuf + protobuf-java + + + + com.google.protobuf + protobuf-java-util + + io.confluent kafka-schema-registry-client @@ -135,6 +145,13 @@ runtime + + + com.squareup.wire + wire-schema + runtime + + javax.ws.rs javax.ws.rs-api @@ -166,6 +183,13 @@ provided + + io.confluent + kafka-protobuf-provider + + provided + + org.openjdk.jol jol-core @@ -192,6 +216,13 @@ test + + io.trino + trino-record-decoder + test-jar + test + + io.trino trino-spi @@ -244,10 +275,23 @@ io.confluent kafka-json-schema-serializer - + test + + io.confluent + kafka-protobuf-serializer + + test + + + org.apache.kafka + kafka-clients + + + + io.confluent kafka-schema-serializer @@ -275,6 +319,16 @@ + + io.trino + trino-maven-plugin + true + + + io.confluent:kafka-protobuf-provider + + + org.basepom.maven duplicate-finder-maven-plugin diff --git a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/EncoderModule.java b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/EncoderModule.java index 5a24abbae3cc..bcce05f9b9e6 100644 --- a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/EncoderModule.java +++ b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/EncoderModule.java @@ -21,6 +21,7 @@ import io.trino.plugin.kafka.encoder.csv.CsvRowEncoderFactory; import io.trino.plugin.kafka.encoder.json.JsonRowEncoder; import io.trino.plugin.kafka.encoder.json.JsonRowEncoderFactory; +import io.trino.plugin.kafka.encoder.protobuf.ProtobufEncoderModule; import io.trino.plugin.kafka.encoder.raw.RawRowEncoder; import io.trino.plugin.kafka.encoder.raw.RawRowEncoderFactory; @@ -39,6 +40,7 @@ public void configure(Binder binder) encoderFactoriesByName.addBinding(RawRowEncoder.NAME).to(RawRowEncoderFactory.class).in(SINGLETON); encoderFactoriesByName.addBinding(JsonRowEncoder.NAME).to(JsonRowEncoderFactory.class).in(SINGLETON); binder.install(new AvroEncoderModule()); + binder.install(new ProtobufEncoderModule()); binder.bind(DispatchingRowEncoderFactory.class).in(SINGLETON); } diff --git a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufEncoderModule.java b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufEncoderModule.java new file mode 100644 index 000000000000..30d765bb6a6c --- /dev/null +++ b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufEncoderModule.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.kafka.encoder.protobuf; + +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.multibindings.MapBinder; +import io.trino.plugin.kafka.encoder.RowEncoderFactory; + +import static com.google.inject.Scopes.SINGLETON; +import static com.google.inject.multibindings.MapBinder.newMapBinder; + +public class ProtobufEncoderModule + implements Module +{ + @Override + public void configure(Binder binder) + { + MapBinder encoderFactoriesByName = newMapBinder(binder, String.class, RowEncoderFactory.class); + encoderFactoriesByName.addBinding(ProtobufRowEncoder.NAME).to(ProtobufRowEncoderFactory.class).in(SINGLETON); + } +} diff --git a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufRowEncoder.java b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufRowEncoder.java new file mode 100644 index 000000000000..565a516002af --- /dev/null +++ b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufRowEncoder.java @@ -0,0 +1,356 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.kafka.encoder.protobuf; + +import com.google.common.base.Splitter; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import com.google.protobuf.ByteString; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Timestamp; +import io.trino.plugin.kafka.encoder.AbstractRowEncoder; +import io.trino.plugin.kafka.encoder.EncoderColumnHandle; +import io.trino.spi.TrinoException; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.SqlTimestamp; +import io.trino.spi.type.SqlVarbinary; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.Type; +import io.trino.spi.type.VarbinaryType; +import io.trino.spi.type.VarcharType; + +import java.nio.ByteBuffer; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.protobuf.Descriptors.FieldDescriptor.JavaType.ENUM; +import static com.google.protobuf.Descriptors.FieldDescriptor.JavaType.MESSAGE; +import static com.google.protobuf.util.Timestamps.checkValid; +import static io.trino.decoder.protobuf.ProtobufErrorCode.INVALID_TIMESTAMP; +import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; 
+import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND; +import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND; +import static java.lang.Math.floorDiv; +import static java.lang.Math.floorMod; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class ProtobufRowEncoder + extends AbstractRowEncoder +{ + public static final String NAME = "protobuf"; + + private static final Set SUPPORTED_PRIMITIVE_TYPES = ImmutableSet.of( + BOOLEAN, INTEGER, BIGINT, DOUBLE, REAL); + + private final Descriptor descriptor; + private final DynamicMessage.Builder messageBuilder; + + public ProtobufRowEncoder(Descriptor descriptor, ConnectorSession session, List columnHandles) + { + super(session, columnHandles); + for (EncoderColumnHandle columnHandle : this.columnHandles) { + checkArgument(columnHandle.getFormatHint() == null, "formatHint must be null"); + checkArgument(columnHandle.getDataFormat() == null, "dataFormat must be null"); + + checkArgument(isSupportedType(columnHandle.getType()), "Unsupported column type '%s' for column '%s'", columnHandle.getType(), columnHandle.getName()); + } + this.descriptor = requireNonNull(descriptor, "descriptor is null"); + this.messageBuilder = DynamicMessage.newBuilder(this.descriptor); + } + + private boolean isSupportedType(Type type) + { + if (isSupportedPrimitive(type)) { + return true; + } + + if (type instanceof ArrayType) { + checkArgument(type.getTypeParameters().size() == 1, "expecting exactly one type parameter for array"); + return isSupportedType(type.getTypeParameters().get(0)); + } + + if (type instanceof MapType) { + List typeParameters = type.getTypeParameters(); + checkArgument(typeParameters.size() == 2, "expecting exactly two type parameters for map"); + return isSupportedType(typeParameters.get(0)) && isSupportedType(typeParameters.get(1)); + } + + if (type instanceof RowType) { + checkArgument(((RowType) type).getFields().stream().allMatch(field -> field.getName().isPresent()), "expecting name for field in rows"); + for (Type fieldType : type.getTypeParameters()) { + if (!isSupportedType(fieldType)) { + return false; + } + } + return true; + } + return false; + } + + private boolean isSupportedPrimitive(Type type) + { + return (type instanceof TimestampType && ((TimestampType) type).isShort()) || + type instanceof VarcharType || + type instanceof VarbinaryType || + SUPPORTED_PRIMITIVE_TYPES.contains(type); + } + + @Override + protected void appendNullValue() + { + throw new TrinoException(GENERIC_INTERNAL_ERROR, "Protobuf doesn't support serializing null values"); + } + + @Override + protected void appendLong(long value) + { + append(value); + } + + @Override + protected void appendInt(int value) + { + append(value); + } + + @Override + protected void appendShort(short value) + { + append(value); + } + + @Override + protected void appendDouble(double value) + { + append(value); + } + + @Override + protected void appendFloat(float value) + { + append(value); + } + + @Override + protected void appendByte(byte value) + { + append(value); + } + + @Override + protected void appendBoolean(boolean value) + { + append(value); + } + + @Override + protected void appendString(String value) + { + append(value); + } + + @Override + protected void appendByteBuffer(ByteBuffer value) + { + append(value); + } + + @Override + protected void appendArray(List value) + { + append(value); + } + + @Override + protected void appendSqlTimestamp(SqlTimestamp value) + { + append(value); + } + + 
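+    // Note: all append* overloads in this class funnel into append(Object),
+    // which routes the value to the Protobuf field selected by the current
+    // column's mapping (possibly a nested path such as "row/field").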
@Override + protected void appendMap(Map value) + { + append(value); + } + + @Override + protected void appendRow(List value) + { + append(value); + } + + @Override + public byte[] toByteArray() + { + resetColumnIndex(); + try { + return messageBuilder.build().toByteArray(); + } + finally { + messageBuilder.clear(); + } + } + + private void append(Object value) + { + setField(descriptor, messageBuilder, columnHandles.get(currentColumnIndex).getType(), columnHandles.get(currentColumnIndex).getMapping(), value); + } + + private DynamicMessage setField(Descriptor descriptor, DynamicMessage.Builder messageBuilder, Type type, String columnMapping, Object value) + { + List columnPath = Splitter.on("/") + .omitEmptyStrings() + .limit(2) + .splitToList(columnMapping); + FieldDescriptor fieldDescriptor = descriptor.findFieldByName(columnPath.get(0)); + checkState(fieldDescriptor != null, format("Unknown Field %s", columnPath.get(0))); + if (columnPath.size() == 2) { + checkState(fieldDescriptor.getJavaType() == MESSAGE, "Expected MESSAGE type, but got: %s", fieldDescriptor.getJavaType()); + value = setField( + fieldDescriptor.getMessageType(), + DynamicMessage.newBuilder((DynamicMessage) messageBuilder.getField(fieldDescriptor)), + type, + columnPath.get(1), + value); + } + else { + value = encodeObject(fieldDescriptor, type, value); + } + setField(fieldDescriptor, messageBuilder, value); + return messageBuilder.build(); + } + + private Object encodeObject(FieldDescriptor fieldDescriptor, Type type, Object value) + { + if (value == null) { + throw new TrinoException(GENERIC_INTERNAL_ERROR, "Protobuf doesn't support serializing null values"); + } + if (type instanceof VarbinaryType) { + if (value instanceof SqlVarbinary) { + return ByteString.copyFrom(((SqlVarbinary) value).getBytes()); + } + if (value instanceof ByteBuffer) { + ByteBuffer byteBuffer = (ByteBuffer) value; + return ByteString.copyFrom(byteBuffer, byteBuffer.limit()); + } + throw new TrinoException(GENERIC_INTERNAL_ERROR, format("cannot decode object of '%s' as '%s'", value.getClass(), type)); + } + if (type instanceof TimestampType) { + checkArgument(value instanceof SqlTimestamp, "value should be an instance of SqlTimestamp"); + return encodeTimestamp((SqlTimestamp) value); + } + if (type instanceof ArrayType) { + checkArgument(value instanceof List, "value should be an instance of List"); + return encodeArray(fieldDescriptor, type, (List) value); + } + if (type instanceof MapType) { + checkArgument(value instanceof Map, "value should be an instance of Map"); + return encodeMap(fieldDescriptor, type, (Map) value); + } + if (type instanceof RowType) { + checkArgument(value instanceof List, "value should be an instance of List"); + return encodeRow(fieldDescriptor, type, (List) value); + } + return value; + } + + private Timestamp encodeTimestamp(SqlTimestamp timestamp) + { + int nanos = floorMod(timestamp.getEpochMicros(), MICROSECONDS_PER_SECOND) * NANOSECONDS_PER_MICROSECOND; + try { + return checkValid(Timestamp.newBuilder() + .setSeconds(floorDiv(timestamp.getEpochMicros(), MICROSECONDS_PER_SECOND)) + .setNanos(nanos) + .build()); + } + catch (IllegalArgumentException e) { + throw new TrinoException(INVALID_TIMESTAMP, e.getMessage()); + } + } + + private List encodeArray(FieldDescriptor fieldDescriptor, Type type, List value) + { + return value.stream() + .map(entry -> encodeObject(fieldDescriptor, type.getTypeParameters().get(0), entry)) + .collect(toImmutableList()); + } + + private List encodeMap(FieldDescriptor 
fieldDescriptor, Type type, Map value) + { + Descriptor descriptor = fieldDescriptor.getMessageType(); + ImmutableList.Builder dynamicMessageListBuilder = ImmutableList.builder(); + for (Map.Entry entry : value.entrySet()) { + DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor); + setField( + descriptor.findFieldByNumber(1), + builder, + encodeObject( + descriptor.findFieldByNumber(1), + type.getTypeParameters().get(0), + entry.getKey())); + setField( + descriptor.findFieldByNumber(2), + builder, + encodeObject( + descriptor.findFieldByNumber(2), + type.getTypeParameters().get(1), + entry.getValue())); + dynamicMessageListBuilder.add(builder.build()); + } + return dynamicMessageListBuilder.build(); + } + + private DynamicMessage encodeRow(FieldDescriptor fieldDescriptor, Type type, List value) + { + Descriptor descriptor = fieldDescriptor.getMessageType(); + DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor); + RowType rowType = (RowType) type; + int index = 0; + for (RowType.Field field : rowType.getFields()) { + checkArgument(field.getName().isPresent(), "FieldName is absent"); + setField( + descriptor.findFieldByName(field.getName().get()), + builder, + encodeObject( + descriptor.findFieldByName(field.getName().get()), + field.getType(), + value.get(index))); + index++; + } + return builder.build(); + } + + private void setField(FieldDescriptor fieldDescriptor, DynamicMessage.Builder builder, Object value) + { + if (fieldDescriptor.getJavaType() == ENUM) { + value = fieldDescriptor.getEnumType().findValueByName((String) value); + } + builder.setField(fieldDescriptor, value); + } +} diff --git a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufRowEncoderFactory.java b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufRowEncoderFactory.java new file mode 100644 index 000000000000..894b510ecfe8 --- /dev/null +++ b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufRowEncoderFactory.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.kafka.encoder.protobuf; + +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.DescriptorValidationException; +import io.trino.plugin.kafka.encoder.EncoderColumnHandle; +import io.trino.plugin.kafka.encoder.RowEncoder; +import io.trino.plugin.kafka.encoder.RowEncoderFactory; +import io.trino.spi.TrinoException; +import io.trino.spi.connector.ConnectorSession; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Preconditions.checkArgument; +import static io.trino.decoder.protobuf.ProtobufErrorCode.INVALID_PROTO_FILE; +import static io.trino.decoder.protobuf.ProtobufErrorCode.MESSAGE_NOT_FOUND; +import static io.trino.decoder.protobuf.ProtobufRowDecoderFactory.DEFAULT_MESSAGE; +import static io.trino.decoder.protobuf.ProtobufUtils.getFileDescriptor; +import static java.lang.String.format; + +public class ProtobufRowEncoderFactory + implements RowEncoderFactory +{ + @Override + public RowEncoder create(ConnectorSession session, Optional dataSchema, List columnHandles) + { + checkArgument(dataSchema.isPresent(), "dataSchema for Protobuf format is not present"); + + try { + Descriptor descriptor = getFileDescriptor(dataSchema.get()).findMessageTypeByName(DEFAULT_MESSAGE); + if (descriptor != null) { + return new ProtobufRowEncoder(descriptor, session, columnHandles); + } + } + catch (DescriptorValidationException descriptorValidationException) { + throw new TrinoException(INVALID_PROTO_FILE, "Unable to parse protobuf schema", descriptorValidationException); + } + throw new TrinoException(MESSAGE_NOT_FOUND, format("Message %s not found", DEFAULT_MESSAGE)); + } +} diff --git a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufSchemaParser.java b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufSchemaParser.java new file mode 100644 index 000000000000..d9b3e808c5a9 --- /dev/null +++ b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufSchemaParser.java @@ -0,0 +1,115 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package io.trino.plugin.kafka.encoder.protobuf;
+
+import com.google.protobuf.Descriptors.Descriptor;
+import com.google.protobuf.Descriptors.FieldDescriptor;
+import io.confluent.kafka.schemaregistry.ParsedSchema;
+import io.confluent.kafka.schemaregistry.protobuf.ProtobufSchema;
+import io.trino.decoder.protobuf.ProtobufRowDecoder;
+import io.trino.plugin.kafka.KafkaTopicFieldDescription;
+import io.trino.plugin.kafka.KafkaTopicFieldGroup;
+import io.trino.plugin.kafka.schema.confluent.SchemaParser;
+import io.trino.spi.connector.ConnectorSession;
+import io.trino.spi.type.ArrayType;
+import io.trino.spi.type.MapType;
+import io.trino.spi.type.RowType;
+import io.trino.spi.type.Type;
+import io.trino.spi.type.TypeManager;
+
+import javax.inject.Inject;
+
+import java.util.Optional;
+
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static io.trino.spi.type.BigintType.BIGINT;
+import static io.trino.spi.type.BooleanType.BOOLEAN;
+import static io.trino.spi.type.DoubleType.DOUBLE;
+import static io.trino.spi.type.IntegerType.INTEGER;
+import static io.trino.spi.type.RealType.REAL;
+import static io.trino.spi.type.TimestampType.createTimestampType;
+import static io.trino.spi.type.VarbinaryType.VARBINARY;
+import static io.trino.spi.type.VarcharType.createUnboundedVarcharType;
+import static java.util.Objects.requireNonNull;
+
+public class ProtobufSchemaParser
+        implements SchemaParser
+{
+    private static final String TIMESTAMP_TYPE_NAME = "google.protobuf.Timestamp";
+    private final TypeManager typeManager;
+
+    @Inject
+    public ProtobufSchemaParser(TypeManager typeManager)
+    {
+        this.typeManager = requireNonNull(typeManager, "typeManager is null");
+    }
+
+    @Override
+    public KafkaTopicFieldGroup parse(ConnectorSession session, String subject, ParsedSchema parsedSchema)
+    {
+        ProtobufSchema protobufSchema = (ProtobufSchema) parsedSchema;
+        return new KafkaTopicFieldGroup(
+                ProtobufRowDecoder.NAME,
+                Optional.empty(),
+                Optional.of(subject),
+                protobufSchema.toDescriptor().getFields().stream()
+                        .map(field -> new KafkaTopicFieldDescription(
+                                field.getName(),
+                                getType(field),
+                                field.getName(),
+                                null,
+                                null,
+                                null,
+                                false))
+                        .collect(toImmutableList()));
+    }
+
+    private Type getType(FieldDescriptor fieldDescriptor)
+    {
+        Type baseType = switch (fieldDescriptor.getJavaType()) {
+            case BOOLEAN -> BOOLEAN;
+            case INT -> INTEGER;
+            case LONG -> BIGINT;
+            case FLOAT -> REAL;
+            case DOUBLE -> DOUBLE;
+            case BYTE_STRING -> VARBINARY;
+            case STRING, ENUM -> createUnboundedVarcharType();
+            case MESSAGE -> getTypeForMessage(fieldDescriptor);
+        };
+
+        // Protobuf does not support adding the repeated label for the map type,
+        // but the schema registry incorrectly adds it
+        if (fieldDescriptor.isRepeated() && !fieldDescriptor.isMapField()) {
+            return new ArrayType(baseType);
+        }
+        return baseType;
+    }
+
+    private Type getTypeForMessage(FieldDescriptor fieldDescriptor)
+    {
+        Descriptor descriptor = fieldDescriptor.getMessageType();
+        if (fieldDescriptor.getMessageType().getFullName().equals(TIMESTAMP_TYPE_NAME)) {
+            return createTimestampType(6);
+        }
+        if (fieldDescriptor.isMapField()) {
+            return new MapType(
+                    getType(descriptor.findFieldByNumber(1)),
+                    getType(descriptor.findFieldByNumber(2)),
+                    typeManager.getTypeOperators());
+        }
+        return RowType.from(
+                fieldDescriptor.getMessageType().getFields().stream()
+                        .map(field -> RowType.field(field.getName(), getType(field)))
+                        .collect(toImmutableList()));
+    }
+}
diff --git
a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/ConfluentModule.java b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/ConfluentModule.java index 63805e3deb8b..4caa15575be8 100644 --- a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/ConfluentModule.java +++ b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/ConfluentModule.java @@ -13,17 +13,22 @@ */ package io.trino.plugin.kafka.schema.confluent; +import com.google.common.base.Suppliers; import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; import com.google.inject.Binder; import com.google.inject.Module; import com.google.inject.Provides; import com.google.inject.Scopes; import com.google.inject.multibindings.MapBinder; import io.airlift.configuration.AbstractConfigurationAwareModule; +import io.confluent.kafka.schemaregistry.ParsedSchema; import io.confluent.kafka.schemaregistry.SchemaProvider; import io.confluent.kafka.schemaregistry.avro.AvroSchemaProvider; import io.confluent.kafka.schemaregistry.client.CachedSchemaRegistryClient; import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient; +import io.confluent.kafka.schemaregistry.client.rest.entities.SchemaReference; +import io.confluent.kafka.schemaregistry.protobuf.ProtobufSchemaProvider; import io.trino.decoder.DispatchingRowDecoderFactory; import io.trino.decoder.RowDecoderFactory; import io.trino.decoder.avro.AvroBytesDeserializer; @@ -32,10 +37,15 @@ import io.trino.decoder.avro.AvroRowDecoderFactory; import io.trino.decoder.dummy.DummyRowDecoder; import io.trino.decoder.dummy.DummyRowDecoderFactory; +import io.trino.decoder.protobuf.DynamicMessageProvider; +import io.trino.decoder.protobuf.ProtobufRowDecoder; +import io.trino.decoder.protobuf.ProtobufRowDecoderFactory; import io.trino.plugin.base.session.SessionPropertiesProvider; import io.trino.plugin.kafka.encoder.DispatchingRowEncoderFactory; import io.trino.plugin.kafka.encoder.RowEncoderFactory; import io.trino.plugin.kafka.encoder.avro.AvroRowEncoder; +import io.trino.plugin.kafka.encoder.protobuf.ProtobufRowEncoder; +import io.trino.plugin.kafka.encoder.protobuf.ProtobufSchemaParser; import io.trino.plugin.kafka.schema.ContentSchemaReader; import io.trino.plugin.kafka.schema.TableDescriptionSupplier; import io.trino.spi.HostAddress; @@ -45,8 +55,12 @@ import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.Set; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; +import static com.google.common.base.Preconditions.checkState; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static com.google.inject.Scopes.SINGLETON; @@ -69,9 +83,13 @@ protected void setup(Binder binder) binder.bind(ContentSchemaReader.class).to(AvroConfluentContentSchemaReader.class).in(Scopes.SINGLETON); newSetBinder(binder, SchemaRegistryClientPropertiesProvider.class); newSetBinder(binder, SchemaProvider.class).addBinding().to(AvroSchemaProvider.class).in(Scopes.SINGLETON); + // Each SchemaRegistry object should have a new instance of SchemaProvider + newSetBinder(binder, SchemaProvider.class).addBinding().to(LazyLoadedProtobufSchemaProvider.class); + binder.bind(DynamicMessageProvider.Factory.class).to(ConfluentSchemaRegistryDynamicMessageProvider.Factory.class).in(SINGLETON); newSetBinder(binder, 
SessionPropertiesProvider.class).addBinding().to(ConfluentSessionProperties.class).in(Scopes.SINGLETON);
         binder.bind(TableDescriptionSupplier.class).toProvider(ConfluentSchemaRegistryTableDescriptionSupplier.Factory.class).in(Scopes.SINGLETON);
         newMapBinder(binder, String.class, SchemaParser.class).addBinding("AVRO").to(AvroSchemaParser.class).in(Scopes.SINGLETON);
+        newMapBinder(binder, String.class, SchemaParser.class).addBinding("PROTOBUF").to(ProtobufSchemaParser.class).in(Scopes.SINGLETON);
     }
 
     @Provides
@@ -112,6 +130,7 @@ public void configure(Binder binder)
             binder.bind(AvroReaderSupplier.Factory.class).to(ConfluentAvroReaderSupplier.Factory.class).in(Scopes.SINGLETON);
             binder.bind(AvroDeserializer.Factory.class).to(AvroBytesDeserializer.Factory.class).in(Scopes.SINGLETON);
             newMapBinder(binder, String.class, RowDecoderFactory.class).addBinding(AvroRowDecoderFactory.NAME).to(AvroRowDecoderFactory.class).in(Scopes.SINGLETON);
+            newMapBinder(binder, String.class, RowDecoderFactory.class).addBinding(ProtobufRowDecoder.NAME).to(ProtobufRowDecoderFactory.class).in(Scopes.SINGLETON);
             newMapBinder(binder, String.class, RowDecoderFactory.class).addBinding(DummyRowDecoder.NAME).to(DummyRowDecoderFactory.class).in(SINGLETON);
             binder.bind(DispatchingRowDecoderFactory.class).in(SINGLETON);
         }
@@ -127,7 +146,47 @@ public void configure(Binder binder)
             encoderFactoriesByName.addBinding(AvroRowEncoder.NAME).toInstance((session, dataSchema, columnHandles) -> {
                 throw new TrinoException(NOT_SUPPORTED, "Insert not supported");
             });
+            encoderFactoriesByName.addBinding(ProtobufRowEncoder.NAME).toInstance((session, dataSchema, columnHandles) -> {
+                throw new TrinoException(NOT_SUPPORTED, "Insert is not supported for schema registry based tables");
+            });
             binder.bind(DispatchingRowEncoderFactory.class).in(SINGLETON);
         }
     }
+
+    private static class LazyLoadedProtobufSchemaProvider
+            implements SchemaProvider
+    {
+        // Load ProtobufSchemaProvider lazily, so that the Kafka connector can be
+        // used for topics that do not use Protobuf without the protobuf provider dependency
+        private final Supplier<SchemaProvider> delegate = Suppliers.memoize(this::create);
+        private final AtomicReference<Map<String, ?>> configuration = new AtomicReference<>();
+
+        @Override
+        public String schemaType()
+        {
+            return "PROTOBUF";
+        }
+
+        @Override
+        public void configure(Map<String, ?> configuration)
+        {
+            Map<String, ?> oldConfiguration = this.configuration.getAndSet(ImmutableMap.copyOf(configuration));
+            checkState(oldConfiguration == null, "ProtobufSchemaProvider is already configured");
+        }
+
+        @Override
+        public Optional<ParsedSchema> parseSchema(String schema, List<SchemaReference> references)
+        {
+            return delegate.get().parseSchema(schema, references);
+        }
+
+        private SchemaProvider create()
+        {
+            ProtobufSchemaProvider schemaProvider = new ProtobufSchemaProvider();
+            Map<String, ?> configuration = this.configuration.get();
+            checkState(configuration != null, "ProtobufSchemaProvider is not configured");
+            schemaProvider.configure(configuration);
+            return schemaProvider;
+        }
+    }
 }
diff --git a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/ConfluentSchemaRegistryDynamicMessageProvider.java b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/ConfluentSchemaRegistryDynamicMessageProvider.java
new file mode 100644
index 000000000000..116417de80a1
--- /dev/null
+++ b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/schema/confluent/ConfluentSchemaRegistryDynamicMessageProvider.java
@@ -0,0 +1,105 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.trino.plugin.kafka.schema.confluent;
+
+import com.google.common.cache.CacheBuilder;
+import com.google.common.cache.CacheLoader;
+import com.google.protobuf.Descriptors.Descriptor;
+import com.google.protobuf.DynamicMessage;
+import io.confluent.kafka.schemaregistry.ParsedSchema;
+import io.confluent.kafka.schemaregistry.client.SchemaRegistryClient;
+import io.confluent.kafka.schemaregistry.client.rest.exceptions.RestClientException;
+import io.confluent.kafka.schemaregistry.protobuf.MessageIndexes;
+import io.confluent.kafka.schemaregistry.protobuf.ProtobufSchema;
+import io.trino.collect.cache.NonEvictableLoadingCache;
+import io.trino.decoder.protobuf.DynamicMessageProvider;
+import io.trino.spi.TrinoException;
+
+import javax.inject.Inject;
+
+import java.io.IOException;
+import java.nio.ByteBuffer;
+import java.util.Optional;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static io.airlift.slice.Slices.wrappedBuffer;
+import static io.trino.collect.cache.SafeCaches.buildNonEvictableCache;
+import static io.trino.decoder.protobuf.ProtobufErrorCode.INVALID_PROTOBUF_MESSAGE;
+import static io.trino.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+
+public class ConfluentSchemaRegistryDynamicMessageProvider
+        implements DynamicMessageProvider
+{
+    private static final int MAGIC_BYTE = 0;
+    private final SchemaRegistryClient schemaRegistryClient;
+    private final NonEvictableLoadingCache<Integer, Descriptor> descriptorCache;
+
+    public ConfluentSchemaRegistryDynamicMessageProvider(SchemaRegistryClient schemaRegistryClient)
+    {
+        this.schemaRegistryClient = requireNonNull(schemaRegistryClient, "schemaRegistryClient is null");
+        descriptorCache = buildNonEvictableCache(
+                CacheBuilder.newBuilder().maximumSize(1000),
+                CacheLoader.from(this::lookupDescriptor));
+    }
+
+    @Override
+    public DynamicMessage parseDynamicMessage(byte[] data)
+    {
+        ByteBuffer buffer = ByteBuffer.wrap(data);
+        byte magicByte = buffer.get();
+        checkArgument(magicByte == MAGIC_BYTE, "Invalid MagicByte");
+        int schemaId = buffer.getInt();
+        MessageIndexes.readFrom(buffer);
+        try {
+            return DynamicMessage.parseFrom(
+                    descriptorCache.getUnchecked(schemaId),
+                    wrappedBuffer(buffer).getInput());
+        }
+        catch (IOException e) {
+            throw new TrinoException(INVALID_PROTOBUF_MESSAGE, "Decoding Protobuf record failed.", e);
+        }
+    }
+
+    private Descriptor lookupDescriptor(int schemaId)
+    {
+        try {
+            ParsedSchema schema = schemaRegistryClient.getSchemaById(schemaId);
+            checkArgument(schema instanceof ProtobufSchema, "schema should be an instance of ProtobufSchema");
+            return ((ProtobufSchema) schema).toDescriptor();
+        }
+        catch (IOException | RestClientException e) {
+            throw new TrinoException(GENERIC_INTERNAL_ERROR, format("Looking up schemaId '%s' from the Confluent schema registry failed", schemaId), e);
+        }
+    }
+
+    public static class Factory
+            implements DynamicMessageProvider.Factory
+    {
+        private final SchemaRegistryClient schemaRegistryClient;
+
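+        // Note: the Optional proto file passed to create() is ignored by this
+        // factory; with the Confluent schema registry, the writer schema is
+        // resolved per record from the schema ID embedded in each message.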
@Inject + public Factory(SchemaRegistryClient schemaRegistryClient) + { + this.schemaRegistryClient = requireNonNull(schemaRegistryClient, "schemaRegistryClient is null"); + } + + @Override + public DynamicMessageProvider create(Optional protoFile) + { + return new ConfluentSchemaRegistryDynamicMessageProvider(schemaRegistryClient); + } + } +} diff --git a/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestKafkaProtobufWithSchemaRegistryMinimalFunctionality.java b/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestKafkaProtobufWithSchemaRegistryMinimalFunctionality.java new file mode 100644 index 000000000000..9b444098e29e --- /dev/null +++ b/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestKafkaProtobufWithSchemaRegistryMinimalFunctionality.java @@ -0,0 +1,324 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.kafka.protobuf; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.DynamicMessage; +import io.confluent.kafka.serializers.protobuf.KafkaProtobufSerializer; +import io.confluent.kafka.serializers.subject.RecordNameStrategy; +import io.confluent.kafka.serializers.subject.TopicRecordNameStrategy; +import io.trino.plugin.kafka.schema.confluent.KafkaWithConfluentSchemaRegistryQueryRunner; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.QueryRunner; +import io.trino.testing.kafka.TestingKafka; +import net.jodah.failsafe.Failsafe; +import net.jodah.failsafe.RetryPolicy; +import org.apache.kafka.clients.producer.ProducerRecord; +import org.testng.annotations.Test; + +import java.time.Duration; +import java.util.List; +import java.util.Map; + +import static com.google.protobuf.Descriptors.FieldDescriptor.JavaType.ENUM; +import static com.google.protobuf.Descriptors.FieldDescriptor.JavaType.STRING; +import static io.confluent.kafka.serializers.AbstractKafkaAvroSerDeConfig.SCHEMA_REGISTRY_URL_CONFIG; +import static io.confluent.kafka.serializers.AbstractKafkaAvroSerDeConfig.VALUE_SUBJECT_NAME_STRATEGY; +import static io.trino.decoder.protobuf.ProtobufRowDecoderFactory.DEFAULT_MESSAGE; +import static io.trino.decoder.protobuf.ProtobufUtils.getFileDescriptor; +import static io.trino.decoder.protobuf.ProtobufUtils.getProtoFile; +import static java.lang.Math.multiplyExact; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; +import static org.apache.kafka.clients.producer.ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG; +import static org.apache.kafka.clients.producer.ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG; +import static org.testng.Assert.assertTrue; + +@Test(singleThreaded = true) +public class TestKafkaProtobufWithSchemaRegistryMinimalFunctionality + extends AbstractTestQueryFramework +{ + private static final String 
RECORD_NAME = "schema"; + private static final int MESSAGE_COUNT = 100; + + private TestingKafka testingKafka; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + testingKafka = closeAfterClass(TestingKafka.createWithSchemaRegistry()); + return KafkaWithConfluentSchemaRegistryQueryRunner.builder(testingKafka).build(); + } + + @Test + public void testBasicTopic() + throws Exception + { + String topic = "topic-basic-MixedCase"; + assertTopic( + topic, + format("SELECT col_1, col_2 FROM %s", toDoubleQuoted(topic)), + format("SELECT col_1, col_2, col_3 FROM %s", toDoubleQuoted(topic)), + false, + producerProperties()); + } + + @Test + public void testTopicWithKeySubject() + throws Exception + { + String topic = "topic-Key-Subject"; + assertTopic( + topic, + format("SELECT key, col_1, col_2 FROM %s", toDoubleQuoted(topic)), + format("SELECT key, col_1, col_2, col_3 FROM %s", toDoubleQuoted(topic)), + true, + producerProperties()); + } + + @Test + public void testTopicWithRecordNameStrategy() + throws Exception + { + String topic = "topic-Record-Name-Strategy"; + assertTopic( + topic, + format("SELECT key, col_1, col_2 FROM \"%1$s&value-subject=%2$s\"", topic, RECORD_NAME), + format("SELECT key, col_1, col_2, col_3 FROM \"%1$s&value-subject=%2$s\"", topic, RECORD_NAME), + true, + ImmutableMap.<String, String>builder() + .putAll(producerProperties()) + .put(VALUE_SUBJECT_NAME_STRATEGY, RecordNameStrategy.class.getName()) + .buildOrThrow()); + } + + @Test + public void testTopicWithTopicRecordNameStrategy() + throws Exception + { + String topic = "topic-Topic-Record-Name-Strategy"; + assertTopic( + topic, + format("SELECT key, col_1, col_2 FROM \"%1$s&value-subject=%1$s-%2$s\"", topic, RECORD_NAME), + format("SELECT key, col_1, col_2, col_3 FROM \"%1$s&value-subject=%1$s-%2$s\"", topic, RECORD_NAME), + true, + ImmutableMap.<String, String>builder() + .putAll(producerProperties()) + .put(VALUE_SUBJECT_NAME_STRATEGY, TopicRecordNameStrategy.class.getName()) + .buildOrThrow()); + } + + @Test + public void testBasicTopicForInsert() + throws Exception + { + String topic = "topic-basic-inserts"; + assertTopic( + topic, + format("SELECT col_1, col_2 FROM %s", toDoubleQuoted(topic)), + format("SELECT col_1, col_2, col_3 FROM %s", toDoubleQuoted(topic)), + false, + producerProperties()); + assertQueryFails( + format("INSERT INTO %s (col_1, col_2, col_3) VALUES ('Trino', 14, 1.4)", toDoubleQuoted(topic)), + "Insert is not supported for schema registry based tables"); + } + + private Map<String, String> producerProperties() + { + return ImmutableMap.of( + SCHEMA_REGISTRY_URL_CONFIG, testingKafka.getSchemaRegistryConnectString(), + KEY_SERIALIZER_CLASS_CONFIG, KafkaProtobufSerializer.class.getName(), + VALUE_SERIALIZER_CLASS_CONFIG, KafkaProtobufSerializer.class.getName()); + } + + private void assertTopic(String topicName, String initialQuery, String evolvedQuery, boolean isKeyIncluded, Map<String, String> producerConfig) + throws Exception + { + testingKafka.createTopic(topicName); + + assertNotExists(topicName); + + List<ProducerRecord<DynamicMessage, DynamicMessage>> messages = createMessages(topicName, MESSAGE_COUNT, true, getInitialSchema(), getKeySchema()); + testingKafka.sendMessages(messages.stream(), producerConfig); + + waitUntilTableExists(topicName); + assertCount(topicName, MESSAGE_COUNT); + + assertQuery(initialQuery, getExpectedValues(messages, getInitialSchema(), isKeyIncluded)); + + List<ProducerRecord<DynamicMessage, DynamicMessage>> newMessages = createMessages(topicName, MESSAGE_COUNT, false, getEvolvedSchema(), getKeySchema()); + testingKafka.sendMessages(newMessages.stream(), producerConfig); + + List<ProducerRecord<DynamicMessage, DynamicMessage>>
allMessages = ImmutableList.<ProducerRecord<DynamicMessage, DynamicMessage>>builder() + .addAll(messages) + .addAll(newMessages) + .build(); + assertCount(topicName, allMessages.size()); + assertQuery(evolvedQuery, getExpectedValues(allMessages, getEvolvedSchema(), isKeyIncluded)); + } + + private static String getExpectedValues(List<ProducerRecord<DynamicMessage, DynamicMessage>> messages, Descriptor descriptor, boolean isKeyIncluded) + { + StringBuilder valuesBuilder = new StringBuilder("VALUES "); + ImmutableList.Builder<String> rowsBuilder = ImmutableList.builder(); + for (ProducerRecord<DynamicMessage, DynamicMessage> message : messages) { + ImmutableList.Builder<String> columnsBuilder = ImmutableList.builder(); + + if (isKeyIncluded) { + addExpectedColumns(message.key().getDescriptorForType(), message.key(), columnsBuilder); + } + + addExpectedColumns(descriptor, message.value(), columnsBuilder); + + rowsBuilder.add(format("(%s)", String.join(", ", columnsBuilder.build()))); + } + valuesBuilder.append(String.join(", ", rowsBuilder.build())); + return valuesBuilder.toString(); + } + + private static void addExpectedColumns(Descriptor descriptor, DynamicMessage message, ImmutableList.Builder<String> columnsBuilder) + { + for (FieldDescriptor field : descriptor.getFields()) { + FieldDescriptor fieldDescriptor = message.getDescriptorForType().findFieldByName(field.getName()); + if (fieldDescriptor == null) { + // Field is absent from this message's (older) schema version, so the decoder is expected to return SQL NULL + columnsBuilder.add("null"); + continue; + } + Object value = message.getField(fieldDescriptor); + if (field.getJavaType() == STRING || field.getJavaType() == ENUM) { + columnsBuilder.add(toSingleQuoted(value)); + } + else { + columnsBuilder.add(String.valueOf(value)); + } + } + } + + private void assertNotExists(String tableName) + { + if (schemaExists()) { + assertQueryReturnsEmptyResult(format("SHOW TABLES LIKE '%s'", tableName)); + } + } + + private void waitUntilTableExists(String tableName) + { + Failsafe.with( + new RetryPolicy<>() + .withMaxAttempts(10) + .withDelay(Duration.ofMillis(100))) + .run(() -> assertTrue(schemaExists())); + Failsafe.with( + new RetryPolicy<>() + .withMaxAttempts(10) + .withDelay(Duration.ofMillis(100))) + .run(() -> assertTrue(tableExists(tableName))); + } + + private boolean schemaExists() + { + return computeActual(format("SHOW SCHEMAS FROM %s LIKE '%s'", getSession().getCatalog().get(), getSession().getSchema().get())).getRowCount() == 1; + } + + private boolean tableExists(String tableName) + { + return computeActual(format("SHOW TABLES LIKE '%s'", tableName.toLowerCase(ENGLISH))).getRowCount() == 1; + } + + private void assertCount(String tableName, int count) + { + assertQuery(format("SELECT count(*) FROM %s", toDoubleQuoted(tableName)), format("VALUES (%s)", count)); + } + + private static Descriptor getInitialSchema() + throws Exception + { + return getDescriptor("initial_schema.proto"); + } + + private static Descriptor getEvolvedSchema() + throws Exception + { + return getDescriptor("evolved_schema.proto"); + } + + private static Descriptor getKeySchema() + throws Exception + { + return getDescriptor("key_schema.proto"); + } + + public static Descriptor getDescriptor(String fileName) + throws Exception + { + return getFileDescriptor(getProtoFile("protobuf/" + fileName)).findMessageTypeByName(DEFAULT_MESSAGE); + } + + private static String toDoubleQuoted(String tableName) + { + return format("\"%s\"", tableName); + } + + private static String toSingleQuoted(Object value) + { + requireNonNull(value, "value is null"); + return format("'%s'", value); + } + + private static List<ProducerRecord<DynamicMessage, DynamicMessage>> createMessages(String topicName, int messageCount, boolean
useInitialSchema, Descriptor descriptor, Descriptor keyDescriptor) + { + ImmutableList.Builder<ProducerRecord<DynamicMessage, DynamicMessage>> producerRecordBuilder = ImmutableList.builder(); + if (useInitialSchema) { + for (long key = 0; key < messageCount; key++) { + producerRecordBuilder.add(new ProducerRecord<>(topicName, createKeySchema(key, keyDescriptor), createRecordWithInitialSchema(key, descriptor))); + } + } + else { + for (long key = 0; key < messageCount; key++) { + producerRecordBuilder.add(new ProducerRecord<>(topicName, createKeySchema(key, keyDescriptor), createRecordWithEvolvedSchema(key, descriptor))); + } + } + return producerRecordBuilder.build(); + } + + private static DynamicMessage createKeySchema(long key, Descriptor descriptor) + { + return DynamicMessage.newBuilder(descriptor) + .setField(descriptor.findFieldByName("key"), key) + .build(); + } + + private static DynamicMessage createRecordWithInitialSchema(long key, Descriptor descriptor) + { + return DynamicMessage.newBuilder(descriptor) + .setField(descriptor.findFieldByName("col_1"), format("string-%s", key)) + .setField(descriptor.findFieldByName("col_2"), multiplyExact(key, 100)) + .build(); + } + + private static DynamicMessage createRecordWithEvolvedSchema(long key, Descriptor descriptor) + { + return DynamicMessage.newBuilder(descriptor) + .setField(descriptor.findFieldByName("col_1"), format("string-%s", key)) + .setField(descriptor.findFieldByName("col_2"), multiplyExact(key, 100)) + .setField(descriptor.findFieldByName("col_3"), (key + 10.1D) / 10.0D) + .build(); + } +} diff --git a/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestProtobufEncoder.java b/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestProtobufEncoder.java new file mode 100644 index 000000000000..4ff338361d66 --- /dev/null +++ b/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestProtobufEncoder.java @@ -0,0 +1,368 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ +package io.trino.plugin.kafka.protobuf; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Timestamp; +import io.airlift.slice.Slices; +import io.trino.decoder.protobuf.ProtobufDataProviders; +import io.trino.plugin.kafka.KafkaColumnHandle; +import io.trino.plugin.kafka.encoder.EncoderColumnHandle; +import io.trino.plugin.kafka.encoder.RowEncoder; +import io.trino.plugin.kafka.encoder.protobuf.ProtobufRowEncoderFactory; +import io.trino.spi.block.Block; +import io.trino.spi.block.BlockBuilder; +import io.trino.spi.type.ArrayType; +import io.trino.spi.type.MapType; +import io.trino.spi.type.RowType; +import io.trino.spi.type.SqlTimestamp; +import io.trino.spi.type.Type; +import io.trino.testing.TestingConnectorSession; +import org.testng.annotations.Test; + +import java.util.List; +import java.util.Optional; + +import static io.airlift.slice.Slices.utf8Slice; +import static io.airlift.slice.Slices.wrappedBuffer; +import static io.trino.decoder.protobuf.ProtobufRowDecoderFactory.DEFAULT_MESSAGE; +import static io.trino.decoder.protobuf.ProtobufUtils.getFileDescriptor; +import static io.trino.decoder.protobuf.ProtobufUtils.getProtoFile; +import static io.trino.spi.block.ArrayBlock.fromElementBlock; +import static io.trino.spi.block.RowBlock.fromFieldBlocks; +import static io.trino.spi.predicate.Utils.nativeValueToBlock; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.TimestampType.createTimestampType; +import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND; +import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND; +import static io.trino.spi.type.TypeSignature.mapType; +import static io.trino.spi.type.TypeUtils.writeNativeValue; +import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.type.InternalTypeManager.TESTING_TYPE_MANAGER; +import static java.lang.Float.floatToIntBits; +import static java.lang.Math.floorDiv; +import static java.lang.Math.floorMod; +import static org.testng.Assert.assertEquals; + +public class TestProtobufEncoder +{ + private static final ProtobufRowEncoderFactory ENCODER_FACTORY = new ProtobufRowEncoderFactory(); + + @Test(dataProvider = "allTypesDataProvider", dataProviderClass = ProtobufDataProviders.class) + public void testAllDataTypes(String stringData, Integer integerData, Long longData, Double doubleData, Float floatData, Boolean booleanData, String enumData, SqlTimestamp sqlTimestamp, byte[] bytesData) + throws Exception + { + Descriptor descriptor = getDescriptor("all_datatypes.proto"); + DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(descriptor); + messageBuilder.setField(descriptor.findFieldByName("stringColumn"), stringData); + messageBuilder.setField(descriptor.findFieldByName("integerColumn"), integerData); + messageBuilder.setField(descriptor.findFieldByName("longColumn"), longData); + messageBuilder.setField(descriptor.findFieldByName("doubleColumn"), doubleData); + messageBuilder.setField(descriptor.findFieldByName("floatColumn"), floatData); + 
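// Note: DynamicMessage.setField needs protobuf-native values for the enum and timestamp fields below: an EnumValueDescriptor and a google.protobuf.Timestamp, not raw Java strings or longs +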
messageBuilder.setField(descriptor.findFieldByName("booleanColumn"), booleanData); + messageBuilder.setField(descriptor.findFieldByName("numberColumn"), descriptor.findEnumTypeByName("Number").findValueByName(enumData)); + messageBuilder.setField(descriptor.findFieldByName("timestampColumn"), getTimestamp(sqlTimestamp)); + messageBuilder.setField(descriptor.findFieldByName("bytesColumn"), bytesData); + + RowEncoder rowEncoder = createRowEncoder( + "all_datatypes.proto", + ImmutableList.of( + createEncoderColumnHandle("stringColumn", createVarcharType(100), "stringColumn"), + createEncoderColumnHandle("integerColumn", INTEGER, "integerColumn"), + createEncoderColumnHandle("longColumn", BIGINT, "longColumn"), + createEncoderColumnHandle("doubleColumn", DOUBLE, "doubleColumn"), + createEncoderColumnHandle("floatColumn", REAL, "floatColumn"), + createEncoderColumnHandle("booleanColumn", BOOLEAN, "booleanColumn"), + createEncoderColumnHandle("numberColumn", createVarcharType(4), "numberColumn"), + createEncoderColumnHandle("timestampColumn", createTimestampType(6), "timestampColumn"), + createEncoderColumnHandle("bytesColumn", VARBINARY, "bytesColumn"))); + + rowEncoder.appendColumnValue(nativeValueToBlock(createVarcharType(5), utf8Slice(stringData)), 0); + rowEncoder.appendColumnValue(nativeValueToBlock(INTEGER, integerData.longValue()), 0); + rowEncoder.appendColumnValue(nativeValueToBlock(BIGINT, longData), 0); + rowEncoder.appendColumnValue(nativeValueToBlock(DOUBLE, doubleData), 0); + rowEncoder.appendColumnValue(nativeValueToBlock(REAL, (long) floatToIntBits(floatData)), 0); + rowEncoder.appendColumnValue(nativeValueToBlock(BOOLEAN, booleanData), 0); + rowEncoder.appendColumnValue(nativeValueToBlock(createVarcharType(4), utf8Slice(enumData)), 0); + rowEncoder.appendColumnValue(nativeValueToBlock(createTimestampType(6), sqlTimestamp.getEpochMicros()), 0); + rowEncoder.appendColumnValue(nativeValueToBlock(VARBINARY, wrappedBuffer(bytesData)), 0); + + assertEquals(messageBuilder.build().toByteArray(), rowEncoder.toByteArray()); + } + + @Test(dataProvider = "allTypesDataProvider", dataProviderClass = ProtobufDataProviders.class) + public void testStructuralDataTypes(String stringData, Integer integerData, Long longData, Double doubleData, Float floatData, Boolean booleanData, String enumData, SqlTimestamp sqlTimestamp, byte[] bytesData) + throws Exception + { + Descriptor descriptor = getDescriptor("structural_datatypes.proto"); + DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(descriptor); + + messageBuilder.setField(descriptor.findFieldByName("list"), ImmutableList.of(stringData)); + + Descriptor mapDescriptor = descriptor.findFieldByName("map").getMessageType(); + DynamicMessage.Builder mapBuilder = DynamicMessage.newBuilder(mapDescriptor); + mapBuilder.setField(mapDescriptor.findFieldByName("key"), "Key"); + mapBuilder.setField(mapDescriptor.findFieldByName("value"), "Value"); + messageBuilder.setField(descriptor.findFieldByName("map"), ImmutableList.of(mapBuilder.build())); + + Descriptor rowDescriptor = descriptor.findFieldByName("row").getMessageType(); + DynamicMessage.Builder rowBuilder = DynamicMessage.newBuilder(rowDescriptor); + rowBuilder.setField(rowDescriptor.findFieldByName("string_column"), stringData); + rowBuilder.setField(rowDescriptor.findFieldByName("integer_column"), integerData); + rowBuilder.setField(rowDescriptor.findFieldByName("long_column"), longData); + rowBuilder.setField(rowDescriptor.findFieldByName("double_column"), doubleData); + 
rowBuilder.setField(rowDescriptor.findFieldByName("float_column"), floatData); + rowBuilder.setField(rowDescriptor.findFieldByName("boolean_column"), booleanData); + rowBuilder.setField(rowDescriptor.findFieldByName("number_column"), descriptor.findEnumTypeByName("Number").findValueByName(enumData)); + rowBuilder.setField(rowDescriptor.findFieldByName("timestamp_column"), getTimestamp(sqlTimestamp)); + rowBuilder.setField(rowDescriptor.findFieldByName("bytes_column"), bytesData); + messageBuilder.setField(descriptor.findFieldByName("row"), rowBuilder.build()); + + List<EncoderColumnHandle> columnHandles = ImmutableList.of( + createEncoderColumnHandle("list", new ArrayType(createVarcharType(30000)), "list"), + createEncoderColumnHandle("map", TESTING_TYPE_MANAGER.getType(mapType(VARCHAR.getTypeSignature(), VARCHAR.getTypeSignature())), "map"), + createEncoderColumnHandle( + "row", + RowType.from(ImmutableList.<RowType.Field>builder() + .add(RowType.field("string_column", createVarcharType(30000))) + .add(RowType.field("integer_column", INTEGER)) + .add(RowType.field("long_column", BIGINT)) + .add(RowType.field("double_column", DOUBLE)) + .add(RowType.field("float_column", REAL)) + .add(RowType.field("boolean_column", BOOLEAN)) + .add(RowType.field("number_column", createVarcharType(4))) + .add(RowType.field("timestamp_column", createTimestampType(6))) + .add(RowType.field("bytes_column", VARBINARY)) + .build()), + "row")); + + RowEncoder rowEncoder = createRowEncoder("structural_datatypes.proto", columnHandles.subList(0, 3)); + + BlockBuilder arrayBlockBuilder = columnHandles.get(0).getType() + .createBlockBuilder(null, 1); + BlockBuilder singleArrayBlockWriter = arrayBlockBuilder.beginBlockEntry(); + writeNativeValue(createVarcharType(5), singleArrayBlockWriter, utf8Slice(stringData)); + arrayBlockBuilder.closeEntry(); + rowEncoder.appendColumnValue(arrayBlockBuilder.build(), 0); + + BlockBuilder mapBlockBuilder = columnHandles.get(1).getType() + .createBlockBuilder(null, 1); + BlockBuilder singleMapBlockWriter = mapBlockBuilder.beginBlockEntry(); + writeNativeValue(VARCHAR, singleMapBlockWriter, utf8Slice("Key")); + writeNativeValue(VARCHAR, singleMapBlockWriter, utf8Slice("Value")); + mapBlockBuilder.closeEntry(); + rowEncoder.appendColumnValue(mapBlockBuilder.build(), 0); + + BlockBuilder rowBlockBuilder = columnHandles.get(2).getType() + .createBlockBuilder(null, 1); + BlockBuilder singleRowBlockWriter = rowBlockBuilder.beginBlockEntry(); + writeNativeValue(VARCHAR, singleRowBlockWriter, Slices.utf8Slice(stringData)); + writeNativeValue(INTEGER, singleRowBlockWriter, integerData.longValue()); + writeNativeValue(BIGINT, singleRowBlockWriter, longData); + writeNativeValue(DOUBLE, singleRowBlockWriter, doubleData); + writeNativeValue(REAL, singleRowBlockWriter, (long) floatToIntBits(floatData)); + writeNativeValue(BOOLEAN, singleRowBlockWriter, booleanData); + writeNativeValue(VARCHAR, singleRowBlockWriter, enumData); + writeNativeValue(createTimestampType(6), singleRowBlockWriter, sqlTimestamp.getEpochMicros()); + writeNativeValue(VARBINARY, singleRowBlockWriter, bytesData); + + rowBlockBuilder.closeEntry(); + rowEncoder.appendColumnValue(rowBlockBuilder.build(), 0); + + assertEquals(messageBuilder.build().toByteArray(), rowEncoder.toByteArray()); + } + + @Test(dataProvider = "allTypesDataProvider", dataProviderClass = ProtobufDataProviders.class) + public void testNestedStructuralDataTypes(String stringData, Integer integerData, Long longData, Double doubleData, Float floatData, Boolean booleanData, String enumData,
SqlTimestamp sqlTimestamp, byte[] bytesData) + throws Exception + { + Descriptor descriptor = getDescriptor("structural_datatypes.proto"); + DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(descriptor); + Descriptor nestedDescriptor = descriptor.findFieldByName("nested_row").getMessageType(); + DynamicMessage.Builder nestedMessageBuilder = DynamicMessage.newBuilder(nestedDescriptor); + + Descriptor rowDescriptor = nestedDescriptor.findFieldByName("row").getMessageType(); + DynamicMessage.Builder rowBuilder = DynamicMessage.newBuilder(rowDescriptor); + rowBuilder.setField(rowDescriptor.findFieldByName("string_column"), stringData); + rowBuilder.setField(rowDescriptor.findFieldByName("integer_column"), integerData); + rowBuilder.setField(rowDescriptor.findFieldByName("long_column"), longData); + rowBuilder.setField(rowDescriptor.findFieldByName("double_column"), doubleData); + rowBuilder.setField(rowDescriptor.findFieldByName("float_column"), floatData); + rowBuilder.setField(rowDescriptor.findFieldByName("boolean_column"), booleanData); + rowBuilder.setField(rowDescriptor.findFieldByName("number_column"), descriptor.findEnumTypeByName("Number").findValueByName(enumData)); + rowBuilder.setField(rowDescriptor.findFieldByName("timestamp_column"), getTimestamp(sqlTimestamp)); + rowBuilder.setField(rowDescriptor.findFieldByName("bytes_column"), bytesData); + nestedMessageBuilder.setField(nestedDescriptor.findFieldByName("nested_list"), ImmutableList.of(rowBuilder.build())); + + Descriptor mapDescriptor = nestedDescriptor.findFieldByName("nested_map").getMessageType(); + DynamicMessage.Builder mapBuilder = DynamicMessage.newBuilder(mapDescriptor); + mapBuilder.setField(mapDescriptor.findFieldByName("key"), "Key"); + mapBuilder.setField(mapDescriptor.findFieldByName("value"), rowBuilder.build()); + nestedMessageBuilder.setField(nestedDescriptor.findFieldByName("nested_map"), ImmutableList.of(mapBuilder.build())); + + nestedMessageBuilder.setField(nestedDescriptor.findFieldByName("row"), rowBuilder.build()); + + messageBuilder.setField(descriptor.findFieldByName("nested_row"), nestedMessageBuilder.build()); + + RowType rowType = RowType.from(ImmutableList.<RowType.Field>builder() + .add(RowType.field("string_column", createVarcharType(30000))) + .add(RowType.field("integer_column", INTEGER)) + .add(RowType.field("long_column", BIGINT)) + .add(RowType.field("double_column", DOUBLE)) + .add(RowType.field("float_column", REAL)) + .add(RowType.field("boolean_column", BOOLEAN)) + .add(RowType.field("number_column", createVarcharType(4))) + .add(RowType.field("timestamp_column", createTimestampType(6))) + .add(RowType.field("bytes_column", VARBINARY)) + .build()); + + List<EncoderColumnHandle> columnHandles = ImmutableList.of( + createEncoderColumnHandle( + "row", + RowType.from(ImmutableList.of( + RowType.field("nested_list", new ArrayType(rowType)), + RowType.field("nested_map", TESTING_TYPE_MANAGER.getType(mapType(VARCHAR.getTypeSignature(), rowType.getTypeSignature()))), + RowType.field("row", rowType))), + "nested_row")); + + RowEncoder rowEncoder = createRowEncoder("structural_datatypes.proto", columnHandles); + + BlockBuilder rowBlockBuilder = rowType + .createBlockBuilder(null, 1); + BlockBuilder singleRowBlockWriter = rowBlockBuilder.beginBlockEntry(); + writeNativeValue(VARCHAR, singleRowBlockWriter, Slices.utf8Slice(stringData)); + writeNativeValue(INTEGER, singleRowBlockWriter, integerData.longValue()); + writeNativeValue(BIGINT, singleRowBlockWriter, longData); + writeNativeValue(DOUBLE, singleRowBlockWriter,
doubleData); + writeNativeValue(REAL, singleRowBlockWriter, (long) floatToIntBits(floatData)); + writeNativeValue(BOOLEAN, singleRowBlockWriter, booleanData); + writeNativeValue(VARCHAR, singleRowBlockWriter, enumData); + writeNativeValue(createTimestampType(6), singleRowBlockWriter, sqlTimestamp.getEpochMicros()); + writeNativeValue(VARBINARY, singleRowBlockWriter, bytesData); + rowBlockBuilder.closeEntry(); + + RowType nestedRowType = (RowType) columnHandles.get(0).getType(); + + MapType mapType = (MapType) nestedRowType.getTypeParameters().get(1); + BlockBuilder mapBlockBuilder = mapType.createBlockBuilder(null, 1); + Block mapBlock = mapType.createBlockFromKeyValue( + Optional.empty(), + new int[]{0, 1}, + nativeValueToBlock(VARCHAR, utf8Slice("Key")), + rowBlockBuilder.build()); + mapType.appendTo( + mapBlock, + 0, + mapBlockBuilder); + + Type listType = nestedRowType.getTypeParameters().get(0); + BlockBuilder listBlockBuilder = listType.createBlockBuilder(null, 1); + Block arrayBlock = fromElementBlock( + 1, + Optional.empty(), + new int[]{0, rowBlockBuilder.getPositionCount()}, + rowBlockBuilder.build()); + listType.appendTo(arrayBlock, 0, listBlockBuilder); + + BlockBuilder nestedBlockBuilder = nestedRowType.createBlockBuilder(null, 1); + Block rowBlock = fromFieldBlocks( + 1, + Optional.empty(), + new Block[]{listBlockBuilder.build(), mapBlockBuilder.build(), rowBlockBuilder.build()}); + nestedRowType.appendTo(rowBlock, 0, nestedBlockBuilder); + + rowEncoder.appendColumnValue(nestedBlockBuilder, 0); + + assertEquals(messageBuilder.build().toByteArray(), rowEncoder.toByteArray()); + } + + @Test(dataProvider = "allTypesDataProvider", dataProviderClass = ProtobufDataProviders.class) + public void testRowFlattening(String stringData, Integer integerData, Long longData, Double doubleData, Float floatData, Boolean booleanData, String enumData, SqlTimestamp sqlTimestamp, byte[] bytesData) + throws Exception + { + Descriptor descriptor = getDescriptor("structural_datatypes.proto"); + DynamicMessage.Builder messageBuilder = DynamicMessage.newBuilder(descriptor); + + Descriptor rowDescriptor = descriptor.findNestedTypeByName("Row"); + DynamicMessage.Builder rowBuilder = DynamicMessage.newBuilder(rowDescriptor); + rowBuilder.setField(rowDescriptor.findFieldByName("string_column"), stringData); + rowBuilder.setField(rowDescriptor.findFieldByName("integer_column"), integerData); + rowBuilder.setField(rowDescriptor.findFieldByName("long_column"), longData); + rowBuilder.setField(rowDescriptor.findFieldByName("double_column"), doubleData); + rowBuilder.setField(rowDescriptor.findFieldByName("float_column"), floatData); + rowBuilder.setField(rowDescriptor.findFieldByName("boolean_column"), booleanData); + rowBuilder.setField(rowDescriptor.findFieldByName("number_column"), descriptor.findEnumTypeByName("Number").findValueByName(enumData)); + rowBuilder.setField(rowDescriptor.findFieldByName("timestamp_column"), getTimestamp(sqlTimestamp)); + rowBuilder.setField(rowDescriptor.findFieldByName("bytes_column"), bytesData); + messageBuilder.setField(descriptor.findFieldByName("row"), rowBuilder.build()); + + RowEncoder rowEncoder = createRowEncoder( + "structural_datatypes.proto", + ImmutableList.of( + createEncoderColumnHandle("stringColumn", createVarcharType(100), "row/string_column"), + createEncoderColumnHandle("integerColumn", INTEGER, "row/integer_column"), + createEncoderColumnHandle("longColumn", BIGINT, "row/long_column"), + createEncoderColumnHandle("doubleColumn", DOUBLE, 
"row/double_column"), + createEncoderColumnHandle("floatColumn", REAL, "row/float_column"), + createEncoderColumnHandle("booleanColumn", BOOLEAN, "row/boolean_column"), + createEncoderColumnHandle("numberColumn", createVarcharType(4), "row/number_column"), + createEncoderColumnHandle("timestampColumn", createTimestampType(4), "row/timestamp_column"), + createEncoderColumnHandle("bytesColumn", VARBINARY, "row/bytes_column"))); + + rowEncoder.appendColumnValue(nativeValueToBlock(createVarcharType(5), utf8Slice(stringData)), 0); + rowEncoder.appendColumnValue(nativeValueToBlock(INTEGER, integerData.longValue()), 0); + rowEncoder.appendColumnValue(nativeValueToBlock(BIGINT, longData), 0); + rowEncoder.appendColumnValue(nativeValueToBlock(DOUBLE, doubleData), 0); + rowEncoder.appendColumnValue(nativeValueToBlock(REAL, (long) floatToIntBits(floatData)), 0); + rowEncoder.appendColumnValue(nativeValueToBlock(BOOLEAN, booleanData), 0); + rowEncoder.appendColumnValue(nativeValueToBlock(createVarcharType(4), utf8Slice(enumData)), 0); + rowEncoder.appendColumnValue(nativeValueToBlock(createTimestampType(6), sqlTimestamp.getEpochMicros()), 0); + rowEncoder.appendColumnValue(nativeValueToBlock(VARBINARY, wrappedBuffer(bytesData)), 0); + + assertEquals(messageBuilder.build().toByteArray(), rowEncoder.toByteArray()); + } + + private Timestamp getTimestamp(SqlTimestamp sqlTimestamp) + { + return Timestamp.newBuilder() + .setSeconds(floorDiv(sqlTimestamp.getEpochMicros(), MICROSECONDS_PER_SECOND)) + .setNanos(floorMod(sqlTimestamp.getEpochMicros(), MICROSECONDS_PER_SECOND) * NANOSECONDS_PER_MICROSECOND) + .build(); + } + + private RowEncoder createRowEncoder(String fileName, List columns) + throws Exception + { + return ENCODER_FACTORY.create(TestingConnectorSession.SESSION, Optional.of(getProtoFile("decoder/protobuf/" + fileName)), columns); + } + + private Descriptor getDescriptor(String fileName) + throws Exception + { + return getFileDescriptor(getProtoFile("decoder/protobuf/" + fileName)).findMessageTypeByName(DEFAULT_MESSAGE); + } + + private static EncoderColumnHandle createEncoderColumnHandle(String name, Type type, String mapping) + { + return new KafkaColumnHandle(name, type, mapping, null, null, false, false, false); + } +} diff --git a/plugin/trino-kafka/src/test/resources/protobuf/default_values.proto b/plugin/trino-kafka/src/test/resources/protobuf/default_values.proto new file mode 100644 index 000000000000..af1fe2819cee --- /dev/null +++ b/plugin/trino-kafka/src/test/resources/protobuf/default_values.proto @@ -0,0 +1,15 @@ +syntax = "proto2"; + +message schema { + optional string stringColumn = 1 [default = "Default"]; + optional uint32 integerColumn = 2 [default = 314]; + optional uint64 longColumn = 3 [default = 314]; + optional double doubleColumn = 4 [default = 3.14]; + optional float floatColumn = 5; + optional bool booleanColumn = 6; + enum Number { + ONE = 0; + TWO = 2; + }; + optional Number numberColumn = 7; +} diff --git a/plugin/trino-kafka/src/test/resources/protobuf/evolved_schema.proto b/plugin/trino-kafka/src/test/resources/protobuf/evolved_schema.proto new file mode 100644 index 000000000000..858e3b0f878e --- /dev/null +++ b/plugin/trino-kafka/src/test/resources/protobuf/evolved_schema.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +message schema { + string col_1 = 1; + uint64 col_2 = 2; + double col_3 = 3; +} diff --git a/plugin/trino-kafka/src/test/resources/protobuf/initial_schema.proto b/plugin/trino-kafka/src/test/resources/protobuf/initial_schema.proto new file mode 
100644 index 000000000000..e5e37a7d7182 --- /dev/null +++ b/plugin/trino-kafka/src/test/resources/protobuf/initial_schema.proto @@ -0,0 +1,6 @@ +syntax = "proto3"; + +message schema { + string col_1 = 1; + uint64 col_2 = 2; +} diff --git a/plugin/trino-kafka/src/test/resources/protobuf/key_schema.proto b/plugin/trino-kafka/src/test/resources/protobuf/key_schema.proto new file mode 100644 index 000000000000..fde18b2877ed --- /dev/null +++ b/plugin/trino-kafka/src/test/resources/protobuf/key_schema.proto @@ -0,0 +1,5 @@ +syntax = "proto3"; + +message schema { + uint64 key = 1; +} diff --git a/plugin/trino-kafka/src/test/resources/protobuf/timestamps.proto b/plugin/trino-kafka/src/test/resources/protobuf/timestamps.proto new file mode 100644 index 000000000000..e3c04e27f977 --- /dev/null +++ b/plugin/trino-kafka/src/test/resources/protobuf/timestamps.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +import "google/protobuf/timestamp.proto"; + +message schema { + google.protobuf.Timestamp precision_0 = 1; + google.protobuf.Timestamp precision_1 = 2; + google.protobuf.Timestamp precision_2 = 3; + google.protobuf.Timestamp precision_3 = 4; + google.protobuf.Timestamp precision_4 = 5; + google.protobuf.Timestamp precision_5 = 6; + google.protobuf.Timestamp precision_6 = 7; +} diff --git a/pom.xml b/pom.xml index 44918bfe3e6c..146e1924f6cf 100644 --- a/pom.xml +++ b/pom.xml @@ -551,6 +551,13 @@ ${project.version} + + io.trino + trino-record-decoder + test-jar + ${project.version} + + io.trino trino-resource-group-managers @@ -1367,7 +1374,7 @@ io.confluent kafka-json-schema-serializer ${dep.confluent.version} - + test @@ -1385,6 +1392,22 @@ + + io.confluent + kafka-protobuf-provider + ${dep.confluent.version} + + provided + + + + io.confluent + kafka-protobuf-serializer + ${dep.confluent.version} + + test + + io.confluent kafka-schema-registry-client diff --git a/testing/trino-product-tests-launcher/pom.xml b/testing/trino-product-tests-launcher/pom.xml index afaa4e4f9667..b73acdbf4376 100644 --- a/testing/trino-product-tests-launcher/pom.xml +++ b/testing/trino-product-tests-launcher/pom.xml @@ -188,6 +188,31 @@ + + org.apache.maven.plugins + maven-dependency-plugin + + + copy + package + + copy + + + false + + + io.confluent + kafka-protobuf-provider + jar + ${project.build.directory} + + + + + + + org.skife.maven really-executable-jar-maven-plugin diff --git a/testing/trino-product-tests-launcher/src/main/java/io/trino/tests/product/launcher/env/common/Kafka.java b/testing/trino-product-tests-launcher/src/main/java/io/trino/tests/product/launcher/env/common/Kafka.java index 75c6b4aab97f..6b52a3bca8f1 100644 --- a/testing/trino-product-tests-launcher/src/main/java/io/trino/tests/product/launcher/env/common/Kafka.java +++ b/testing/trino-product-tests-launcher/src/main/java/io/trino/tests/product/launcher/env/common/Kafka.java @@ -24,6 +24,7 @@ import javax.inject.Inject; +import java.io.File; import java.time.Duration; import static io.trino.tests.product.launcher.docker.ContainerUtil.forSelectedPorts; @@ -31,6 +32,7 @@ import static io.trino.tests.product.launcher.env.common.Standard.CONTAINER_TRINO_ETC; import static java.util.Objects.requireNonNull; import static org.testcontainers.containers.wait.strategy.Wait.forLogMessage; +import static org.testcontainers.utility.MountableFile.forClasspathResource; import static org.testcontainers.utility.MountableFile.forHostPath; public class Kafka @@ -38,6 +40,7 @@ public class Kafka { private static final String CONFLUENT_VERSION = "5.5.2"; private
static final int SCHEMA_REGISTRY_PORT = 8081; + private static final File KAFKA_PROTOBUF_PROVIDER = new File("testing/trino-product-tests-launcher/target/kafka-protobuf-provider-5.5.2.jar"); static final String KAFKA = "kafka"; static final String SCHEMA_REGISTRY = "schema-registry"; static final String ZOOKEEPER = "zookeeper"; @@ -64,7 +67,10 @@ public void extendEnvironment(Environment.Builder builder) builder.configureContainers(container -> { if (isTrinoContainer(container.getLogicalName())) { MountableFile logConfigFile = forHostPath(configDir.getPath("log.properties")); - container.withCopyFileToContainer(logConfigFile, CONTAINER_TRINO_ETC + "/log.properties"); + container + .withCopyFileToContainer(logConfigFile, CONTAINER_TRINO_ETC + "/log.properties") + .withCopyFileToContainer(forHostPath(KAFKA_PROTOBUF_PROVIDER.getAbsolutePath()), "/docker/kafka-protobuf-provider/kafka-protobuf-provider.jar") + .withCopyFileToContainer(forClasspathResource("install-kafka-protobuf-provider.sh", 0755), "/docker/presto-init.d/install-kafka-protobuf-provider.sh"); } }); diff --git a/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/environment/multinode-kafka-sasl-plaintext/kafka.properties b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/environment/multinode-kafka-sasl-plaintext/kafka.properties index b791b6839447..f9e1ec21ba2d 100644 --- a/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/environment/multinode-kafka-sasl-plaintext/kafka.properties +++ b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/environment/multinode-kafka-sasl-plaintext/kafka.properties @@ -14,7 +14,11 @@ kafka.table-names=product_tests.read_simple_key_and_value,\ product_tests.write_structural_datatype_avro,\ product_tests.pushdown_partition,\ product_tests.pushdown_offset,\ - product_tests.pushdown_create_time + product_tests.pushdown_create_time,\ + product_tests.all_datatypes_protobuf,\ + product_tests.structural_datatype_protobuf,\ + product_tests.read_basic_datatypes_protobuf,\ + product_tests.read_basic_structural_datatypes_protobuf kafka.nodes=kafka:9092 kafka.table-description-dir=/docker/presto-product-tests/conf/presto/etc/catalog/kafka kafka.config.resources=/docker/presto-product-tests/conf/presto/etc/kafka-configuration.properties diff --git a/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/environment/multinode-kafka-ssl/kafka.properties b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/environment/multinode-kafka-ssl/kafka.properties index dd501fe1c792..ccfe80589327 100644 --- a/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/environment/multinode-kafka-ssl/kafka.properties +++ b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/environment/multinode-kafka-ssl/kafka.properties @@ -14,7 +14,11 @@ kafka.table-names=product_tests.read_simple_key_and_value,\ product_tests.write_structural_datatype_avro,\ product_tests.pushdown_partition,\ product_tests.pushdown_offset,\ - product_tests.pushdown_create_time + product_tests.pushdown_create_time,\ + product_tests.all_datatypes_protobuf,\ + product_tests.structural_datatype_protobuf,\ + product_tests.read_basic_datatypes_protobuf,\ + product_tests.read_basic_structural_datatypes_protobuf kafka.nodes=kafka:9092 
kafka.table-description-dir=/docker/presto-product-tests/conf/presto/etc/catalog/kafka kafka.security-protocol=SSL diff --git a/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/environment/multinode-kafka/kafka.properties b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/environment/multinode-kafka/kafka.properties index d2882a96dd0c..bd6631b76e9c 100644 --- a/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/environment/multinode-kafka/kafka.properties +++ b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/environment/multinode-kafka/kafka.properties @@ -14,6 +14,10 @@ kafka.table-names=product_tests.read_simple_key_and_value,\ product_tests.write_structural_datatype_avro,\ product_tests.pushdown_partition,\ product_tests.pushdown_offset,\ - product_tests.pushdown_create_time + product_tests.pushdown_create_time,\ + product_tests.all_datatypes_protobuf,\ + product_tests.structural_datatype_protobuf,\ + product_tests.read_basic_datatypes_protobuf,\ + product_tests.read_basic_structural_datatypes_protobuf kafka.nodes=kafka:9092 kafka.table-description-dir=/docker/presto-product-tests/conf/presto/etc/catalog/kafka diff --git a/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/all_datatypes.proto b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/all_datatypes.proto new file mode 100644 index 000000000000..104ad559fa8e --- /dev/null +++ b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/all_datatypes.proto @@ -0,0 +1,19 @@ +syntax = "proto3"; + +import "google/protobuf/timestamp.proto"; + +message schema { + string a_varchar = 1; + uint32 b_integer = 2; + uint64 c_bigint = 3; + double d_double = 4; + float e_float = 5; + bool f_boolean = 6; + enum Number { + ZERO = 0; + ONE = 1; + TWO = 2; + }; + Number g_enum = 7; + google.protobuf.Timestamp h_timestamp = 8; +} diff --git a/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/all_datatypes_protobuf.json b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/all_datatypes_protobuf.json new file mode 100644 index 000000000000..06027efe66c7 --- /dev/null +++ b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/all_datatypes_protobuf.json @@ -0,0 +1,55 @@ +{ + "tableName": "all_datatypes_protobuf", + "schemaName": "product_tests", + "topicName": "all_datatypes_protobuf", + "message": { + "dataFormat": "protobuf", + "dataSchema": "/docker/presto-product-tests/conf/presto/etc/catalog/kafka/all_datatypes.proto", + "fields": [ + { + "name": "h_varchar", + "type": "VARCHAR", + "mapping": "a_varchar" + }, + { + "name": "g_integer", + "type": "INTEGER", + "mapping": "b_integer" + }, + { + "name": "f_bigint", + "type": "BIGINT", + "mapping": "c_bigint" + }, + { + "name": "e_double", + "type": "DOUBLE", + "mapping": "d_double" + }, + { + "name": "d_real", + "type": "REAL", + "mapping": "e_float" + }, + { + "name": "c_boolean", + "type": "BOOLEAN", + "mapping": "f_boolean" + }, + { + "name": "b_enum", + "type": "VARCHAR", + "mapping": "g_enum" + }, + { + "name": "a_timestamp", + "type": "TIMESTAMP(6)", + "mapping": 
"h_timestamp" + } + ] + }, + "key": { + "dataFormat": "raw", + "fields": [] + } +} diff --git a/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/basic_datatypes.proto b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/basic_datatypes.proto new file mode 100644 index 000000000000..8cb1b8434468 --- /dev/null +++ b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/basic_datatypes.proto @@ -0,0 +1,10 @@ +syntax = "proto3"; + +message schema { + string a_varchar = 1; + uint32 b_integer = 2; + uint64 c_bigint = 3; + double d_double = 4; + float e_float = 5; + bool f_boolean = 6; +} diff --git a/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/basic_structural_datatypes.proto b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/basic_structural_datatypes.proto new file mode 100644 index 000000000000..1a1530dd2664 --- /dev/null +++ b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/basic_structural_datatypes.proto @@ -0,0 +1,6 @@ +syntax = "proto3"; + +message schema { + repeated uint64 a_array = 1; + map a_map = 2; +} diff --git a/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/read_basic_datatypes_protobuf.json b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/read_basic_datatypes_protobuf.json new file mode 100644 index 000000000000..3318bf2b881c --- /dev/null +++ b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/read_basic_datatypes_protobuf.json @@ -0,0 +1,41 @@ +{ + "tableName": "read_basic_datatypes_protobuf", + "schemaName": "product_tests", + "topicName": "read_basic_datatypes_protobuf", + "message": { + "dataFormat": "protobuf", + "dataSchema": "/docker/presto-product-tests/conf/presto/etc/catalog/kafka/basic_datatypes.proto", + "fields": [ + { + "name": "a_varchar", + "type": "VARCHAR", + "mapping": "a_varchar" + }, + { + "name": "b_integer", + "type": "INTEGER", + "mapping": "b_integer" + }, + { + "name": "c_bigint", + "type": "BIGINT", + "mapping": "c_bigint" + }, + { + "name": "d_double", + "type": "DOUBLE", + "mapping": "d_double" + }, + { + "name": "e_real", + "type": "REAL", + "mapping": "e_float" + }, + { + "name": "f_boolean", + "type": "BOOLEAN", + "mapping": "f_boolean" + } + ] + } +} diff --git a/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/read_basic_structural_datatypes_protobuf.json b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/read_basic_structural_datatypes_protobuf.json new file mode 100644 index 000000000000..5a6296aff4e3 --- /dev/null +++ b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/read_basic_structural_datatypes_protobuf.json @@ -0,0 +1,21 @@ +{ + "tableName": "read_basic_structural_datatypes_protobuf", + "schemaName": "product_tests", + "topicName": "read_basic_structural_datatypes_protobuf", + "message": { + "dataFormat": "protobuf", + "dataSchema": 
"/docker/presto-product-tests/conf/presto/etc/catalog/kafka/basic_structural_datatypes.proto", + "fields": [ + { + "name": "c_array", + "type": "ARRAY(BIGINT)", + "mapping": "a_array" + }, + { + "name": "c_map", + "type": "MAP(VARCHAR, DOUBLE)", + "mapping": "a_map" + } + ] + } +} diff --git a/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/structural_datatype.proto b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/structural_datatype.proto new file mode 100644 index 000000000000..a00efc706c43 --- /dev/null +++ b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/structural_datatype.proto @@ -0,0 +1,28 @@ +syntax = "proto3"; + +import "google/protobuf/timestamp.proto"; + +message schema { + message SimpleRow { + string simple_string = 1; + }; + repeated SimpleRow a_array = 1; + map b_map = 2; + enum Number { + ZERO = 0; + ONE = 1; + TWO = 2; + }; + message Row { + string a_string = 1; + uint32 b_integer = 2; + uint64 c_bigint = 3; + double d_double = 4; + float e_float = 5; + bool f_boolean = 6; + Number g_enum = 7; + google.protobuf.Timestamp h_timestamp = 8; + SimpleRow simple_row = 9; + }; + Row c_row = 3; +} diff --git a/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/structural_datatype_protobuf.json b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/structural_datatype_protobuf.json new file mode 100644 index 000000000000..5404a69266f4 --- /dev/null +++ b/testing/trino-product-tests-launcher/src/main/resources/docker/presto-product-tests/conf/presto/etc/catalog/kafka/structural_datatype_protobuf.json @@ -0,0 +1,55 @@ +{ + "tableName": "structural_datatype_protobuf", + "schemaName": "product_tests", + "topicName": "structural_datatype_protobuf", + "message": { + "dataFormat": "protobuf", + "dataSchema": "/docker/presto-product-tests/conf/presto/etc/catalog/kafka/structural_datatype.proto", + "fields": [ + { + "name": "c_array", + "type": "ARRAY(ROW(simple_string VARCHAR))", + "mapping": "a_array" + }, + { + "name": "b_map", + "type": "MAP(VARCHAR, ROW(simple_string VARCHAR))", + "mapping": "b_map" + }, + { + "name": "a_row", + "type": "ROW(d_double DOUBLE, e_float REAL, g_enum VARCHAR)", + "mapping": "c_row" + }, + { + "name": "a_string", + "type": "VARCHAR", + "mapping": "c_row/a_string" + }, + { + "name": "c_integer", + "type": "INTEGER", + "mapping": "c_row/b_integer" + }, + { + "name": "c_bigint", + "type": "BIGINT", + "mapping": "c_row/c_bigint" + }, + { + "name": "d_row", + "type": "ROW(simple_string VARCHAR)", + "mapping": "c_row/simple_row" + }, + { + "name": "e_timestamp", + "type": "TIMESTAMP(6)", + "mapping": "c_row/h_timestamp" + } + ] + }, + "key": { + "dataFormat": "raw", + "fields": [] + } +} diff --git a/testing/trino-product-tests-launcher/src/main/resources/install-kafka-protobuf-provider.sh b/testing/trino-product-tests-launcher/src/main/resources/install-kafka-protobuf-provider.sh new file mode 100644 index 000000000000..3e81ba2c06e6 --- /dev/null +++ b/testing/trino-product-tests-launcher/src/main/resources/install-kafka-protobuf-provider.sh @@ -0,0 +1,3 @@ +#!/bin/bash +set -xeuo pipefail +cp --no-clobber --verbose /docker/kafka-protobuf-provider/* /docker/presto-server/plugin/kafka diff --git a/testing/trino-product-tests/pom.xml 
b/testing/trino-product-tests/pom.xml index bf9990966b51..3fa9e3e66115 100644 --- a/testing/trino-product-tests/pom.xml +++ b/testing/trino-product-tests/pom.xml @@ -151,6 +151,11 @@ guice + + com.google.protobuf + protobuf-java + + com.squareup.okhttp3 okhttp @@ -166,6 +171,19 @@ okhttp-urlconnection + + io.confluent + kafka-protobuf-provider + + compile + + + com.squareup.okio + okio + + + + io.confluent kafka-schema-registry-client diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/kafka/TestKafkaProtobuf.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/kafka/TestKafkaProtobuf.java new file mode 100644 index 000000000000..7d4720f166af --- /dev/null +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/kafka/TestKafkaProtobuf.java @@ -0,0 +1,128 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.tests.product.kafka; + +import com.google.common.collect.ImmutableList; +import io.trino.tempto.ProductTest; +import io.trino.tempto.fulfillment.table.TableManager; +import io.trino.tempto.fulfillment.table.kafka.KafkaTableDefinition; +import io.trino.tempto.fulfillment.table.kafka.KafkaTableManager; +import io.trino.tempto.fulfillment.table.kafka.ListKafkaDataSource; +import org.testng.annotations.Test; + +import java.sql.SQLException; +import java.sql.Timestamp; + +import static io.trino.tempto.assertions.QueryAssert.Row.row; +import static io.trino.tempto.assertions.QueryAssert.assertQueryFailure; +import static io.trino.tempto.assertions.QueryAssert.assertThat; +import static io.trino.tempto.context.ThreadLocalTestContextHolder.testContext; +import static io.trino.tempto.fulfillment.table.TableHandle.tableHandle; +import static io.trino.tests.product.TestGroups.KAFKA; +import static io.trino.tests.product.TestGroups.PROFILE_SPECIFIC_TESTS; +import static io.trino.tests.product.utils.QueryExecutors.onTrino; +import static java.lang.String.format; + +public class TestKafkaProtobuf + extends ProductTest +{ + private static final String KAFKA_CATALOG = "kafka"; + private static final String KAFKA_SCHEMA = "product_tests"; + + private static final String ALL_DATATYPES_PROTOBUF_TABLE_NAME = "all_datatypes_protobuf"; + private static final String ALL_DATATYPES_PROTOBUF_TOPIC_NAME = "write_all_datatypes_protobuf"; + + private static final String STRUCTURAL_PROTOBUF_TABLE_NAME = "structural_datatype_protobuf"; + private static final String STRUCTURAL_PROTOBUF_TOPIC_NAME = "structural_datatype_protobuf"; + + private static void createProtobufTable(String tableName, String topicName) + { + KafkaTableDefinition tableDefinition = new KafkaTableDefinition( + tableName, + topicName, + new ListKafkaDataSource(ImmutableList.of()), + 1, + 1); + KafkaTableManager kafkaTableManager = (KafkaTableManager) testContext().getDependency(TableManager.class, "kafka"); + kafkaTableManager.createImmutable(tableDefinition, tableHandle(tableName).inSchema(KAFKA_SCHEMA)); + } + + @Test(groups = {KAFKA, PROFILE_SPECIFIC_TESTS}) + 
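// Round-trips values through the Protobuf encoder: INSERT via Trino, read the rows back, then verify that a NULL insert is rejected +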
public void testInsertAllDataType() + { + createProtobufTable(ALL_DATATYPES_PROTOBUF_TABLE_NAME, ALL_DATATYPES_PROTOBUF_TOPIC_NAME); + assertThat(onTrino().executeQuery(format( + "INSERT INTO %s.%s.%s VALUES " + + "('Chennai', 314, 9223372036854775807, 1234567890.123456789, 3.14, true, 'ZERO', TIMESTAMP '2020-12-21 15:45:00.012345')," + + "('TamilNadu', -314, -9223372036854775808, -1234567890.123456789, -3.14, false, 'ONE', TIMESTAMP '1970-01-01 15:45:00.012345'), " + + "('India', 314, 9223372036854775807, 1234567890.123456789, 3.14, false, 'TWO', TIMESTAMP '0001-01-01 00:00:00.000001')", + KAFKA_CATALOG, + KAFKA_SCHEMA, + ALL_DATATYPES_PROTOBUF_TABLE_NAME))) + .updatedRowsCountIsEqualTo(3); + + assertThat(onTrino().executeQuery(format( + "SELECT * FROM %s.%s.%s", + KAFKA_CATALOG, + KAFKA_SCHEMA, + ALL_DATATYPES_PROTOBUF_TABLE_NAME))) + .containsOnly( + row("Chennai", 314, 9223372036854775807L, 1234567890.123456789, 3.14f, true, "ZERO", Timestamp.valueOf("2020-12-21 15:45:00.012345")), + row("TamilNadu", -314, -9223372036854775808L, -1234567890.123456789, -3.14f, false, "ONE", Timestamp.valueOf("1970-01-01 15:45:00.012345")), + row("India", 314, 9223372036854775807L, 1234567890.123456789, 3.14f, false, "TWO", Timestamp.valueOf("0001-01-01 00:00:00.000001"))); + + assertQueryFailure(() -> onTrino().executeQuery(format( + "INSERT INTO %s.%s.%s (h_varchar) VALUES ('Chennai')", KAFKA_CATALOG, KAFKA_SCHEMA, ALL_DATATYPES_PROTOBUF_TABLE_NAME))) + .isInstanceOf(SQLException.class) + .hasMessageMatching("Query failed \\(.+\\): Protobuf doesn't support serializing null values"); + } + + @Test(groups = {KAFKA, PROFILE_SPECIFIC_TESTS}) + public void testInsertStructuralDataType() + { + createProtobufTable(STRUCTURAL_PROTOBUF_TABLE_NAME, STRUCTURAL_PROTOBUF_TOPIC_NAME); + assertThat(onTrino().executeQuery(format( + "INSERT INTO %s.%s.%s VALUES " + + "(ARRAY[CAST(ROW('Entry1') AS ROW(simple_string VARCHAR))], " + + "map_from_entries(ARRAY[('key1', CAST(ROW('value1') AS ROW(simple_string VARCHAR)))]), " + + "CAST(ROW(1234567890.123456789, 3.14, 'ONE') AS ROW(d_double DOUBLE, e_float REAL, g_enum VARCHAR)), " + + "'Chennai', " + + "314, " + + "9223372036854775807, " + + "CAST(ROW('Entry2') AS ROW(simple_string VARCHAR)), " + + "TIMESTAMP '2020-12-21 15:45:00.012345')", + KAFKA_CATALOG, + KAFKA_SCHEMA, + STRUCTURAL_PROTOBUF_TABLE_NAME))) + .updatedRowsCountIsEqualTo(1); + + assertThat(onTrino().executeQuery(format( + "SELECT c_array[1].simple_string, b_map['key1'].simple_string, a_row.d_double, a_row.e_float, a_row.g_enum, a_string, c_integer, c_bigint, d_row.simple_string, e_timestamp FROM %s.%s.%s", + KAFKA_CATALOG, + KAFKA_SCHEMA, + STRUCTURAL_PROTOBUF_TABLE_NAME))) + .containsOnly( + row( + "Entry1", + "value1", + 1234567890.1234567890, + 3.14f, + "ONE", + "Chennai", + 314, + 9223372036854775807L, + "Entry2", + Timestamp.valueOf("2020-12-21 15:45:00.012345"))); + } +} diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/kafka/TestKafkaProtobufReads.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/kafka/TestKafkaProtobufReads.java new file mode 100644 index 000000000000..f46a27ee0db4 --- /dev/null +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/kafka/TestKafkaProtobufReads.java @@ -0,0 +1,317 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.tests.product.kafka; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.io.Resources; +import com.google.common.primitives.Ints; +import com.google.protobuf.Descriptors.Descriptor; +import com.google.protobuf.Descriptors.FieldDescriptor; +import com.google.protobuf.DynamicMessage; +import io.airlift.units.Duration; +import io.confluent.kafka.schemaregistry.client.rest.entities.SchemaReference; +import io.confluent.kafka.schemaregistry.protobuf.ProtobufSchema; +import io.trino.tempto.ProductTest; +import io.trino.tempto.fulfillment.table.TableManager; +import io.trino.tempto.fulfillment.table.kafka.KafkaMessage; +import io.trino.tempto.fulfillment.table.kafka.KafkaTableDefinition; +import io.trino.tempto.fulfillment.table.kafka.KafkaTableManager; +import io.trino.tempto.fulfillment.table.kafka.ListKafkaDataSource; +import io.trino.tempto.query.QueryResult; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.io.ByteArrayOutputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.sql.Timestamp; +import java.time.LocalDateTime; +import java.time.ZoneOffset; +import java.util.Map; + +import static io.trino.tempto.assertions.QueryAssert.Row.row; +import static io.trino.tempto.assertions.QueryAssert.assertThat; +import static io.trino.tempto.context.ThreadLocalTestContextHolder.testContext; +import static io.trino.tempto.fulfillment.table.TableHandle.tableHandle; +import static io.trino.tempto.fulfillment.table.kafka.KafkaMessageContentsBuilder.contentsBuilder; +import static io.trino.tests.product.TestGroups.KAFKA; +import static io.trino.tests.product.TestGroups.PROFILE_SPECIFIC_TESTS; +import static io.trino.tests.product.utils.QueryAssertions.assertEventually; +import static io.trino.tests.product.utils.QueryExecutors.onTrino; +import static io.trino.tests.product.utils.SchemaRegistryClientUtils.getSchemaRegistryClient; +import static java.lang.String.format; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.SECONDS; + +@Test(singleThreaded = true) +public class TestKafkaProtobufReads + extends ProductTest +{ + private static final String KAFKA_SCHEMA = "product_tests"; + + private static final String BASIC_DATATYPES_PROTOBUF_TOPIC_NAME = "read_basic_datatypes_protobuf"; + private static final String BASIC_DATATYPES_SCHEMA_PATH = "/docker/presto-product-tests/conf/presto/etc/catalog/kafka/basic_datatypes.proto"; + + private static final String BASIC_STRUCTURAL_PROTOBUF_TOPIC_NAME = "read_basic_structural_datatypes_protobuf"; + private static final String BASIC_STRUCTURAL_SCHEMA_PATH = "/docker/presto-product-tests/conf/presto/etc/catalog/kafka/basic_structural_datatypes.proto"; + + private static final String ALL_DATATYPES_PROTOBUF_TOPIC_SCHEMA_REGISTRY = "all_datatypes_protobuf_schema_registry"; + private static final String ALL_DATATYPES_SCHEMA_PATH = 
"/docker/presto-product-tests/conf/presto/etc/catalog/kafka/all_datatypes.proto"; + + @Test(groups = {KAFKA, PROFILE_SPECIFIC_TESTS}, dataProvider = "catalogs") + public void testSelectPrimitiveDataType(KafkaCatalog kafkaCatalog, MessageSerializer messageSerializer) + throws Exception + { + Map record = ImmutableMap.builder() + .put("a_varchar", "foobar") + .put("b_integer", 314) + .put("c_bigint", 9223372036854775807L) + .put("d_double", 1234567890.123456789) + .put("e_float", 3.14f) + .put("f_boolean", true) + .buildOrThrow(); + String topicName = BASIC_DATATYPES_PROTOBUF_TOPIC_NAME + kafkaCatalog.getTopicNameSuffix(); + createProtobufTable(BASIC_DATATYPES_SCHEMA_PATH, BASIC_DATATYPES_PROTOBUF_TOPIC_NAME, topicName, record, messageSerializer); + + assertEventually( + new Duration(30, SECONDS), + () -> { + QueryResult queryResult = onTrino().executeQuery(format("select * from %s.%s", kafkaCatalog.getCatalogName(), KAFKA_SCHEMA + "." + topicName)); + assertThat(queryResult).containsOnly(row( + "foobar", + 314, + 9223372036854775807L, + 1234567890.123456789, + 3.14f, + true)); + }); + } + + @Test(groups = {KAFKA, PROFILE_SPECIFIC_TESTS}, dataProvider = "catalogs") + public void testSelectStructuralDataType(KafkaCatalog kafkaCatalog, MessageSerializer messageSerializer) + throws Exception + { + ImmutableMap record = ImmutableMap.of( + "a_array", ImmutableList.of(100L, 101L), + "a_map", ImmutableMap.of( + "key", "key1", + "value", 1234567890.123456789)); + String topicName = BASIC_STRUCTURAL_PROTOBUF_TOPIC_NAME + kafkaCatalog.getTopicNameSuffix(); + createProtobufTable(BASIC_STRUCTURAL_SCHEMA_PATH, BASIC_STRUCTURAL_PROTOBUF_TOPIC_NAME, topicName, record, messageSerializer); + assertEventually( + new Duration(30, SECONDS), + () -> { + QueryResult queryResult = onTrino().executeQuery(format( + "SELECT a[1], a[2], m['key1'] FROM (SELECT %s as a, %s as m FROM %s.%s) t", + kafkaCatalog.isColumnMappingSupported() ? "c_array" : "a_array", + kafkaCatalog.isColumnMappingSupported() ? "c_map" : "a_map", + kafkaCatalog.getCatalogName(), + KAFKA_SCHEMA + "." 
+                    assertThat(queryResult).containsOnly(row(100L, 101L, 1234567890.123456789));
+                });
+    }
+
+    @DataProvider
+    public static Object[][] catalogs()
+    {
+        return new Object[][] {
+                {
+                        new KafkaCatalog("kafka", "", true), new ProtobufMessageSerializer(),
+                },
+                {
+                        new KafkaCatalog("kafka_schema_registry", "_schema_registry", false), new SchemaRegistryProtobufMessageSerializer(),
+                },
+        };
+    }
+
+    private static final class KafkaCatalog
+    {
+        private final String catalogName;
+        private final String topicNameSuffix;
+        private final boolean columnMappingSupported;
+
+        private KafkaCatalog(String catalogName, String topicNameSuffix, boolean columnMappingSupported)
+        {
+            this.catalogName = requireNonNull(catalogName, "catalogName is null");
+            this.topicNameSuffix = requireNonNull(topicNameSuffix, "topicNameSuffix is null");
+            this.columnMappingSupported = columnMappingSupported;
+        }
+
+        public String getCatalogName()
+        {
+            return catalogName;
+        }
+
+        public String getTopicNameSuffix()
+        {
+            return topicNameSuffix;
+        }
+
+        public boolean isColumnMappingSupported()
+        {
+            return columnMappingSupported;
+        }
+
+        @Override
+        public String toString()
+        {
+            return catalogName;
+        }
+    }
+
+    @Test(groups = {KAFKA, PROFILE_SPECIFIC_TESTS})
+    public void testProtobufWithSchemaReferences()
+            throws Exception
+    {
+        String timestampTopic = "timestamp";
+        String timestampProtoFile = "google/protobuf/timestamp.proto";
+        ProtobufSchema baseSchema = new ProtobufSchema(
+                Resources.toString(Resources.getResource(TestKafkaProtobufReads.class, "/" + timestampProtoFile), UTF_8),
+                ImmutableList.of(),
+                ImmutableMap.of(),
+                null,
+                timestampProtoFile);
+
+        getSchemaRegistryClient().register(timestampTopic, baseSchema);
+
+        ProtobufSchema actualSchema = new ProtobufSchema(
+                Files.readString(Path.of(ALL_DATATYPES_SCHEMA_PATH)),
+                ImmutableList.of(new SchemaReference(baseSchema.name(), timestampTopic, 1)),
+                ImmutableMap.of(timestampProtoFile, baseSchema.canonicalString()),
+                null,
+                null);
+
+        LocalDateTime timestamp = LocalDateTime.parse("2020-12-12T15:35:45.923");
+        com.google.protobuf.Timestamp timestampProto = com.google.protobuf.Timestamp.newBuilder()
+                .setSeconds(timestamp.toEpochSecond(ZoneOffset.UTC))
+                .setNanos(timestamp.getNano())
+                .build();
+
+        Map<String, Object> record = ImmutableMap.<String, Object>builder()
+                .put("a_varchar", "foobar")
+                .put("b_integer", 2)
+                .put("c_bigint", 9223372036854775807L)
+                .put("d_double", 1234567890.123456789)
+                .put("e_float", 3.14f)
+                .put("f_boolean", true)
+                .put("h_timestamp", timestampProto)
+                .buildOrThrow();
+
+        // This is a bit hacky as KafkaTableManager relies on kafka catalog's tables for inserting data into a given topic
+        createProtobufTable(actualSchema, BASIC_DATATYPES_PROTOBUF_TOPIC_NAME, ALL_DATATYPES_PROTOBUF_TOPIC_SCHEMA_REGISTRY, record, new SchemaRegistryProtobufMessageSerializer());
+
+        assertEventually(
+                new Duration(30, SECONDS),
+                () -> {
+                    QueryResult queryResult = onTrino().executeQuery(format("select * from kafka_schema_registry.%s.%s", KAFKA_SCHEMA, ALL_DATATYPES_PROTOBUF_TOPIC_SCHEMA_REGISTRY));
+                    assertThat(queryResult).containsOnly(row(
+                            "foobar",
+                            2,
+                            9223372036854775807L,
+                            1234567890.123456789,
+                            3.14f,
+                            true,
+                            "ZERO",
+                            Timestamp.valueOf(timestamp)));
+                });
+    }
+
+    private static void createProtobufTable(String schemaPath, String tableName, String topicName, Map<String, Object> record, MessageSerializer messageSerializer)
+            throws Exception
+    {
+        createProtobufTable(new ProtobufSchema(Files.readString(Path.of(schemaPath))), tableName, topicName, record, messageSerializer);
+    }
+
+    private static void createProtobufTable(ProtobufSchema protobufSchema, String tableName, String topicName, Map<String, Object> record, MessageSerializer messageSerializer)
+            throws Exception
+    {
+        byte[] protobufData = messageSerializer.serialize(topicName, protobufSchema, record);
+
+        KafkaTableDefinition tableDefinition = new KafkaTableDefinition(
+                KAFKA_SCHEMA + "." + tableName,
+                topicName,
+                new ListKafkaDataSource(ImmutableList.of(
+                        new KafkaMessage(
+                                contentsBuilder()
+                                        .appendBytes(protobufData)
+                                        .build()))),
+                1,
+                1);
+        KafkaTableManager kafkaTableManager = (KafkaTableManager) testContext().getDependency(TableManager.class, "kafka");
+        kafkaTableManager.createImmutable(tableDefinition, tableHandle(tableName).inSchema(KAFKA_SCHEMA));
+    }
+
+    @FunctionalInterface
+    private interface MessageSerializer
+    {
+        byte[] serialize(String topic, ProtobufSchema protobufSchema, Map<String, Object> values)
+                throws Exception;
+    }
+
+    private static final class ProtobufMessageSerializer
+            implements MessageSerializer
+    {
+        @Override
+        public byte[] serialize(String topic, ProtobufSchema protobufSchema, Map<String, Object> values)
+        {
+            return buildDynamicMessage(protobufSchema.toDescriptor(), values).toByteArray();
+        }
+    }
+
+    private static final class SchemaRegistryProtobufMessageSerializer
+            implements MessageSerializer
+    {
+        @Override
+        public byte[] serialize(String topic, ProtobufSchema protobufSchema, Map<String, Object> values)
+                throws Exception
+        {
+            try (ByteArrayOutputStream out = new ByteArrayOutputStream()) {
+                // Write magic byte
+                out.write((byte) 0);
+
+                // Write SchemaId
+                int schemaId = getSchemaRegistryClient().register(
+                        topic + "-value",
+                        protobufSchema);
+                out.write(Ints.toByteArray(schemaId));
+
+                // Write empty MessageIndexes (a single 0 byte is the compact encoding of the index list [0], i.e. the first message type in the schema)
+                out.write((byte) 0);
+
+                out.write(buildDynamicMessage(protobufSchema.toDescriptor(), values).toByteArray());
+                return out.toByteArray();
+            }
+        }
+    }
+
+    private static DynamicMessage buildDynamicMessage(Descriptor descriptor, Map<String, Object> data)
+    {
+        DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor);
+        for (Map.Entry<String, Object> entry : data.entrySet()) {
+            FieldDescriptor fieldDescriptor = descriptor.findFieldByName(entry.getKey());
+            if (entry.getValue() instanceof Map) {
+                // Protobuf map fields are repeated key/value entry messages; build one entry and wrap it in a list
+                builder.setField(
+                        fieldDescriptor,
+                        ImmutableList.of(
+                                buildDynamicMessage(fieldDescriptor.getMessageType(), (Map<String, Object>) entry.getValue())));
+            }
+            else {
+                builder.setField(fieldDescriptor, entry.getValue());
+            }
+        }
+        return builder.build();
+    }
+}
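
Note on the envelope written by SchemaRegistryProtobufMessageSerializer above: the Confluent wire format frames each Kafka message as a one-byte magic byte (0), a four-byte big-endian schema id, a list of message indexes locating the message type within the .proto file, and finally the serialized protobuf payload; writing a single 0 byte is the compact encoding of the index list [0], i.e. the first message type in the schema. The sketch below shows how a reader could peel that framing off again. It is not part of this patch: the ConfluentEnvelope class name is made up for illustration, the zig-zag varint encoding of the indexes is an assumption taken from Confluent's wire-format documentation, and only the JDK is used.

    import java.nio.ByteBuffer;

    final class ConfluentEnvelope
    {
        private ConfluentEnvelope() {}

        // Returns the raw protobuf payload of a schema-registry framed message
        static byte[] payload(byte[] message)
        {
            ByteBuffer buffer = ByteBuffer.wrap(message);
            byte magic = buffer.get(); // magic byte, 0 in the current format
            if (magic != 0) {
                throw new IllegalArgumentException("Unexpected magic byte: " + magic);
            }
            int schemaId = buffer.getInt(); // big-endian, matching Ints.toByteArray; a real reader would resolve it against the registry
            // Message indexes: a varint count followed by that many varint entries.
            // The serializer above writes the single-byte shorthand for [0].
            long count = readZigZagVarint(buffer);
            for (long i = 0; i < count; i++) {
                readZigZagVarint(buffer);
            }
            byte[] payload = new byte[buffer.remaining()];
            buffer.get(payload);
            return payload; // the DynamicMessage bytes appended last by the serializer
        }

        // Assumption: indexes use zig-zag varints, per Confluent's wire-format docs
        private static long readZigZagVarint(ByteBuffer buffer)
        {
            long raw = 0;
            int shift = 0;
            byte b;
            do {
                b = buffer.get();
                raw |= (long) (b & 0x7f) << shift;
                shift += 7;
            }
            while ((b & 0x80) != 0);
            return (raw >>> 1) ^ -(raw & 1); // zig-zag decode
        }
    }

Applied to the bytes produced by SchemaRegistryProtobufMessageSerializer.serialize, payload(...) returns exactly the serialized DynamicMessage, so parsing it with the schema's descriptor round-trips the record.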