diff --git a/.github/labeler.yml b/.github/labeler.yml index cf1d2a7117203..84dfa35f2627e 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -155,3 +155,6 @@ CONNECT: - "connector/connect/**/*" - "**/sql/sparkconnect/**/*" - "python/pyspark/sql/**/connect/**/*" +PROTOBUF: + - "connector/protobuf/**/*" + - "python/pyspark/sql/protobuf/**/*" \ No newline at end of file diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml new file mode 100644 index 0000000000000..0515f128b8d63 --- /dev/null +++ b/connector/protobuf/pom.xml @@ -0,0 +1,115 @@ + + + + + 4.0.0 + + org.apache.spark + spark-parent_2.12 + 3.4.0-SNAPSHOT + ../../pom.xml + + + spark-protobuf_2.12 + + protobuf + 3.21.1 + + jar + Spark Protobuf + https://spark.apache.org/ + + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + provided + + + org.apache.spark + spark-core_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-catalyst_${scala.binary.version} + ${project.version} + test-jar + test + + + org.apache.spark + spark-sql_${scala.binary.version} + ${project.version} + test-jar + test + + + org.scalacheck + scalacheck_${scala.binary.version} + test + + + org.apache.spark + spark-tags_${scala.binary.version} + + + + com.google.protobuf + protobuf-java + ${protobuf.version} + compile + + + + + target/scala-${scala.binary.version}/classes + target/scala-${scala.binary.version}/test-classes + + + org.apache.maven.plugins + maven-shade-plugin + + false + + + com.google.protobuf:* + + + + + com.google.protobuf + ${spark.shade.packageName}.spark-protobuf.protobuf + + com.google.protobuf.** + + + + + + + + diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala new file mode 100644 index 0000000000000..145100268c232 --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/CatalystDataToProtobuf.scala @@ -0,0 +1,54 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.protobuf + +import com.google.protobuf.DynamicMessage + +import org.apache.spark.sql.catalyst.expressions.{Expression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} +import org.apache.spark.sql.protobuf.utils.ProtobufUtils +import org.apache.spark.sql.types.{BinaryType, DataType} + +private[protobuf] case class CatalystDataToProtobuf( + child: Expression, + descFilePath: String, + messageName: String) + extends UnaryExpression { + + override def dataType: DataType = BinaryType + + @transient private lazy val protoType = + ProtobufUtils.buildDescriptor(descFilePath, messageName) + + @transient private lazy val serializer = + new ProtobufSerializer(child.dataType, protoType, child.nullable) + + override def nullSafeEval(input: Any): Any = { + val dynamicMessage = serializer.serialize(input).asInstanceOf[DynamicMessage] + dynamicMessage.toByteArray + } + + override def prettyName: String = "to_protobuf" + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val expr = ctx.addReferenceObj("this", this) + defineCodeGen(ctx, ev, input => s"(byte[]) $expr.nullSafeEval($input)") + } + + override protected def withNewChildInternal(newChild: Expression): CatalystDataToProtobuf = + copy(child = newChild) +} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala new file mode 100644 index 0000000000000..f08f876799723 --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDataToCatalyst.scala @@ -0,0 +1,160 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.protobuf + +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal + +import com.google.protobuf.DynamicMessage + +import org.apache.spark.SparkException +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, SpecificInternalRow, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode} +import org.apache.spark.sql.catalyst.util.{FailFastMode, ParseMode, PermissiveMode} +import org.apache.spark.sql.protobuf.utils.{ProtobufOptions, ProtobufUtils, SchemaConverters} +import org.apache.spark.sql.types.{AbstractDataType, BinaryType, DataType, StructType} + +private[protobuf] case class ProtobufDataToCatalyst( + child: Expression, + descFilePath: String, + messageName: String, + options: Map[String, String]) + extends UnaryExpression + with ExpectsInputTypes { + + override def inputTypes: Seq[AbstractDataType] = Seq(BinaryType) + + override lazy val dataType: DataType = { + val dt = SchemaConverters.toSqlType(messageDescriptor).dataType + parseMode match { + // With PermissiveMode, the output Catalyst row might contain columns of null values for + // corrupt records, even if some of the columns are not nullable in the user-provided schema. + // Therefore we force the schema to be all nullable here. + case PermissiveMode => dt.asNullable + case _ => dt + } + } + + override def nullable: Boolean = true + + private lazy val protobufOptions = ProtobufOptions(options) + + @transient private lazy val messageDescriptor = + ProtobufUtils.buildDescriptor(descFilePath, messageName) + + @transient private lazy val fieldsNumbers = + messageDescriptor.getFields.asScala.map(f => f.getNumber) + + @transient private lazy val deserializer = new ProtobufDeserializer(messageDescriptor, dataType) + + @transient private var result: DynamicMessage = _ + + @transient private lazy val parseMode: ParseMode = { + val mode = protobufOptions.parseMode + if (mode != PermissiveMode && mode != FailFastMode) { + throw new AnalysisException(unacceptableModeMessage(mode.name)) + } + mode + } + + private def unacceptableModeMessage(name: String): String = { + s"from_protobuf() doesn't support the $name mode. " + + s"Acceptable modes are ${PermissiveMode.name} and ${FailFastMode.name}." + } + + @transient private lazy val nullResultRow: Any = dataType match { + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + for (i <- 0 until st.length) { + resultRow.setNullAt(i) + } + resultRow + + case _ => + null + } + + private def handleException(e: Throwable): Any = { + parseMode match { + case PermissiveMode => + nullResultRow + case FailFastMode => + throw new SparkException( + "Malformed records are detected in record parsing. " + + s"Current parse Mode: ${FailFastMode.name}. 
To process malformed records as null " + + "result, try setting the option 'mode' as 'PERMISSIVE'.", + e) + case _ => + throw new AnalysisException(unacceptableModeMessage(parseMode.name)) + } + } + + override def nullSafeEval(input: Any): Any = { + val binary = input.asInstanceOf[Array[Byte]] + try { + result = DynamicMessage.parseFrom(messageDescriptor, binary) + val unknownFields = result.getUnknownFields + if (!unknownFields.asMap().isEmpty) { + unknownFields.asMap().keySet().asScala.map { number => + { + if (fieldsNumbers.contains(number)) { + return handleException( + new Throwable(s"Type mismatch encountered for field:" + + s" ${messageDescriptor.getFields.get(number)}")) + } + } + } + } + val deserialized = deserializer.deserialize(result) + assert( + deserialized.isDefined, + "Protobuf deserializer cannot return an empty result because filters are not pushed down") + deserialized.get + } catch { + // There could be multiple possible exceptions here, e.g. java.io.IOException, + // ProtoRuntimeException, ArrayIndexOutOfBoundsException, etc. + // To make it simple, catch all the exceptions here. + case NonFatal(e) => + handleException(e) + } + } + + override def prettyName: String = "from_protobuf" + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val expr = ctx.addReferenceObj("this", this) + nullSafeCodeGen( + ctx, + ev, + eval => { + val result = ctx.freshName("result") + val dt = CodeGenerator.boxedType(dataType) + s""" + $dt $result = ($dt) $expr.nullSafeEval($eval); + if ($result == null) { + ${ev.isNull} = true; + } else { + ${ev.value} = $result; + } + """ + }) + } + + override protected def withNewChildInternal(newChild: Expression): ProtobufDataToCatalyst = + copy(child = newChild) +} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala new file mode 100644 index 0000000000000..0403b741ebfa7 --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufDeserializer.scala @@ -0,0 +1,357 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.protobuf + +import java.util.concurrent.TimeUnit + +import com.google.protobuf.{ByteString, DynamicMessage, Message} +import com.google.protobuf.Descriptors._ +import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._ + +import org.apache.spark.sql.catalyst.{InternalRow, NoopFilters, StructFilters} +import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeArrayData} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, DateTimeUtils, GenericArrayData} +import org.apache.spark.sql.errors.QueryCompilationErrors +import org.apache.spark.sql.protobuf.utils.ProtobufUtils +import org.apache.spark.sql.protobuf.utils.ProtobufUtils.ProtoMatchedField +import org.apache.spark.sql.protobuf.utils.ProtobufUtils.toFieldStr +import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +private[sql] class ProtobufDeserializer( + rootDescriptor: Descriptor, + rootCatalystType: DataType, + filters: StructFilters) { + + def this(rootDescriptor: Descriptor, rootCatalystType: DataType) = { + this(rootDescriptor, rootCatalystType, new NoopFilters) + } + + private val converter: Any => Option[InternalRow] = + try { + rootCatalystType match { + // A shortcut for empty schema. + case st: StructType if st.isEmpty => + (_: Any) => Some(InternalRow.empty) + + case st: StructType => + val resultRow = new SpecificInternalRow(st.map(_.dataType)) + val fieldUpdater = new RowUpdater(resultRow) + val applyFilters = filters.skipRow(resultRow, _) + val writer = getRecordWriter(rootDescriptor, st, Nil, Nil, applyFilters) + (data: Any) => { + val record = data.asInstanceOf[DynamicMessage] + val skipRow = writer(fieldUpdater, record) + if (skipRow) None else Some(resultRow) + } + } + } catch { + case ise: IncompatibleSchemaException => + throw new IncompatibleSchemaException( + s"Cannot convert Protobuf type ${rootDescriptor.getName} " + + s"to SQL type ${rootCatalystType.sql}.", + ise) + } + + def deserialize(data: Message): Option[InternalRow] = converter(data) + + private def newArrayWriter( + protoField: FieldDescriptor, + protoPath: Seq[String], + catalystPath: Seq[String], + elementType: DataType, + containsNull: Boolean): (CatalystDataUpdater, Int, Any) => Unit = { + + val protoElementPath = protoPath :+ "element" + val elementWriter = + newWriter(protoField, elementType, protoElementPath, catalystPath :+ "element") + (updater, ordinal, value) => + val collection = value.asInstanceOf[java.util.Collection[Any]] + val result = createArrayData(elementType, collection.size()) + val elementUpdater = new ArrayDataUpdater(result) + + var i = 0 + val iterator = collection.iterator() + while (iterator.hasNext) { + val element = iterator.next() + if (element == null) { + if (!containsNull) { + throw QueryCompilationErrors.nullableArrayOrMapElementError(protoElementPath) + } else { + elementUpdater.setNullAt(i) + } + } else { + elementWriter(elementUpdater, i, element) + } + i += 1 + } + + updater.set(ordinal, result) + } + + private def newMapWriter( + protoType: FieldDescriptor, + protoPath: Seq[String], + catalystPath: Seq[String], + keyType: DataType, + valueType: DataType, + valueContainsNull: Boolean): (CatalystDataUpdater, Int, Any) => Unit = { + val keyField = protoType.getMessageType.getFields.get(0) + val valueField = protoType.getMessageType.getFields.get(1) + val keyWriter = newWriter(keyField, keyType, protoPath :+ "key", 
catalystPath :+ "key") + val valueWriter = + newWriter(valueField, valueType, protoPath :+ "value", catalystPath :+ "value") + (updater, ordinal, value) => + if (value != null) { + val messageList = value.asInstanceOf[java.util.List[com.google.protobuf.Message]] + val valueArray = createArrayData(valueType, messageList.size()) + val valueUpdater = new ArrayDataUpdater(valueArray) + val keyArray = createArrayData(keyType, messageList.size()) + val keyUpdater = new ArrayDataUpdater(keyArray) + var i = 0 + messageList.forEach { field => + { + keyWriter(keyUpdater, i, field.getField(keyField)) + if (field.getField(valueField) == null) { + if (!valueContainsNull) { + throw QueryCompilationErrors.nullableArrayOrMapElementError(protoPath) + } else { + valueUpdater.setNullAt(i) + } + } else { + valueWriter(valueUpdater, i, field.getField(valueField)) + } + } + i += 1 + } + updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) + } + } + + /** + * Creates a writer to write Protobuf values to Catalyst values at the given ordinal with the + * given updater. + */ + private def newWriter( + protoType: FieldDescriptor, + catalystType: DataType, + protoPath: Seq[String], + catalystPath: Seq[String]): (CatalystDataUpdater, Int, Any) => Unit = { + val errorPrefix = s"Cannot convert Protobuf ${toFieldStr(protoPath)} to " + + s"SQL ${toFieldStr(catalystPath)} because " + val incompatibleMsg = errorPrefix + + s"schema is incompatible (protoType = ${protoType} ${protoType.toProto.getLabel} " + + s"${protoType.getJavaType} ${protoType.getType}, sqlType = ${catalystType.sql})" + + (protoType.getJavaType, catalystType) match { + + case (null, NullType) => (updater, ordinal, _) => updater.setNullAt(ordinal) + + // TODO: we can avoid boxing if future version of Protobuf provide primitive accessors. 
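+      // Each writer produced below takes (updater, ordinal, protobuf value) and stores the
+      // converted value at `ordinal` in the target Catalyst container; (JavaType, DataType)
+      // combinations that are not matched fall through to the IncompatibleSchemaException
+      // case at the end of this match.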
+ case (BOOLEAN, BooleanType) => + (updater, ordinal, value) => updater.setBoolean(ordinal, value.asInstanceOf[Boolean]) + + case (INT, IntegerType) => + (updater, ordinal, value) => updater.setInt(ordinal, value.asInstanceOf[Int]) + + case (INT, ByteType) => + (updater, ordinal, value) => updater.setByte(ordinal, value.asInstanceOf[Byte]) + + case (INT, ShortType) => + (updater, ordinal, value) => updater.setShort(ordinal, value.asInstanceOf[Short]) + + case (BOOLEAN | INT | FLOAT | DOUBLE | LONG | STRING | ENUM | BYTE_STRING, + ArrayType(dataType: DataType, containsNull)) if protoType.isRepeated => + newArrayWriter(protoType, protoPath, catalystPath, dataType, containsNull) + + case (LONG, LongType) => + (updater, ordinal, value) => updater.setLong(ordinal, value.asInstanceOf[Long]) + + case (FLOAT, FloatType) => + (updater, ordinal, value) => updater.setFloat(ordinal, value.asInstanceOf[Float]) + + case (DOUBLE, DoubleType) => + (updater, ordinal, value) => updater.setDouble(ordinal, value.asInstanceOf[Double]) + + case (STRING, StringType) => + (updater, ordinal, value) => + val str = value match { + case s: String => UTF8String.fromString(s) + } + updater.set(ordinal, str) + + case (BYTE_STRING, BinaryType) => + (updater, ordinal, value) => + val byte_array = value match { + case s: ByteString => s.toByteArray + case _ => throw new Exception("Invalid ByteString format") + } + updater.set(ordinal, byte_array) + + case (MESSAGE, MapType(keyType, valueType, valueContainsNull)) => + newMapWriter(protoType, protoPath, catalystPath, keyType, valueType, valueContainsNull) + + case (MESSAGE, TimestampType) => + (updater, ordinal, value) => + val secondsField = protoType.getMessageType.getFields.get(0) + val nanoSecondsField = protoType.getMessageType.getFields.get(1) + val message = value.asInstanceOf[DynamicMessage] + val seconds = message.getField(secondsField).asInstanceOf[Long] + val nanoSeconds = message.getField(nanoSecondsField).asInstanceOf[Int] + val micros = DateTimeUtils.millisToMicros(seconds * 1000) + updater.setLong(ordinal, micros + TimeUnit.NANOSECONDS.toMicros(nanoSeconds)) + + case (MESSAGE, DayTimeIntervalType(startField, endField)) => + (updater, ordinal, value) => + val secondsField = protoType.getMessageType.getFields.get(0) + val nanoSecondsField = protoType.getMessageType.getFields.get(1) + val message = value.asInstanceOf[DynamicMessage] + val seconds = message.getField(secondsField).asInstanceOf[Long] + val nanoSeconds = message.getField(nanoSecondsField).asInstanceOf[Int] + val micros = DateTimeUtils.millisToMicros(seconds * 1000) + updater.setLong(ordinal, micros + TimeUnit.NANOSECONDS.toMicros(nanoSeconds)) + + case (MESSAGE, st: StructType) => + val writeRecord = getRecordWriter( + protoType.getMessageType, + st, + protoPath, + catalystPath, + applyFilters = _ => false) + (updater, ordinal, value) => + val row = new SpecificInternalRow(st) + writeRecord(new RowUpdater(row), value.asInstanceOf[DynamicMessage]) + updater.set(ordinal, row) + + case (MESSAGE, ArrayType(st: StructType, containsNull)) => + newArrayWriter(protoType, protoPath, catalystPath, st, containsNull) + + case (ENUM, StringType) => + (updater, ordinal, value) => updater.set(ordinal, UTF8String.fromString(value.toString)) + + case _ => throw new IncompatibleSchemaException(incompatibleMsg) + } + } + + private def getRecordWriter( + protoType: Descriptor, + catalystType: StructType, + protoPath: Seq[String], + catalystPath: Seq[String], + applyFilters: Int => Boolean): (CatalystDataUpdater, 
DynamicMessage) => Boolean = { + + val protoSchemaHelper = + new ProtobufUtils.ProtoSchemaHelper(protoType, catalystType, protoPath, catalystPath) + + // TODO revisit validation of protobuf-catalyst fields. + // protoSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = true) + + var i = 0 + val (validFieldIndexes, fieldWriters) = protoSchemaHelper.matchedFields + .map { case ProtoMatchedField(catalystField, ordinal, protoField) => + val baseWriter = newWriter( + protoField, + catalystField.dataType, + protoPath :+ protoField.getName, + catalystPath :+ catalystField.name) + val fieldWriter = (fieldUpdater: CatalystDataUpdater, value: Any) => { + if (value == null) { + fieldUpdater.setNullAt(ordinal) + } else { + baseWriter(fieldUpdater, ordinal, value) + } + } + i += 1 + (protoField, fieldWriter) + } + .toArray + .unzip + + (fieldUpdater, record) => { + var i = 0 + var skipRow = false + while (i < validFieldIndexes.length && !skipRow) { + val field = validFieldIndexes(i) + val value = if (field.isRepeated || field.hasDefaultValue || record.hasField(field)) { + record.getField(field) + } else null + fieldWriters(i)(fieldUpdater, value) + skipRow = applyFilters(i) + i += 1 + } + skipRow + } + } + + // TODO: All of the code below this line is same between protobuf and avro, it can be shared. + private def createArrayData(elementType: DataType, length: Int): ArrayData = elementType match { + case BooleanType => UnsafeArrayData.fromPrimitiveArray(new Array[Boolean](length)) + case ByteType => UnsafeArrayData.fromPrimitiveArray(new Array[Byte](length)) + case ShortType => UnsafeArrayData.fromPrimitiveArray(new Array[Short](length)) + case IntegerType => UnsafeArrayData.fromPrimitiveArray(new Array[Int](length)) + case LongType => UnsafeArrayData.fromPrimitiveArray(new Array[Long](length)) + case FloatType => UnsafeArrayData.fromPrimitiveArray(new Array[Float](length)) + case DoubleType => UnsafeArrayData.fromPrimitiveArray(new Array[Double](length)) + case _ => new GenericArrayData(new Array[Any](length)) + } + + /** + * A base interface for updating values inside catalyst data structure like `InternalRow` and + * `ArrayData`. 
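+   * The typed setters and `setNullAt` default to delegating to `set`, so the concrete
+   * `RowUpdater` and `ArrayDataUpdater` below only override them to use the specialized
+   * setters of the underlying container.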
+ */ + sealed trait CatalystDataUpdater { + def set(ordinal: Int, value: Any): Unit + def setNullAt(ordinal: Int): Unit = set(ordinal, null) + def setBoolean(ordinal: Int, value: Boolean): Unit = set(ordinal, value) + def setByte(ordinal: Int, value: Byte): Unit = set(ordinal, value) + def setShort(ordinal: Int, value: Short): Unit = set(ordinal, value) + def setInt(ordinal: Int, value: Int): Unit = set(ordinal, value) + def setLong(ordinal: Int, value: Long): Unit = set(ordinal, value) + def setDouble(ordinal: Int, value: Double): Unit = set(ordinal, value) + def setFloat(ordinal: Int, value: Float): Unit = set(ordinal, value) + def setDecimal(ordinal: Int, value: Decimal): Unit = set(ordinal, value) + } + + final class RowUpdater(row: InternalRow) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = row.update(ordinal, value) + override def setNullAt(ordinal: Int): Unit = row.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = row.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = row.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = row.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = row.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = row.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = row.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = row.setFloat(ordinal, value) + override def setDecimal(ordinal: Int, value: Decimal): Unit = + row.setDecimal(ordinal, value, value.precision) + } + + final class ArrayDataUpdater(array: ArrayData) extends CatalystDataUpdater { + override def set(ordinal: Int, value: Any): Unit = array.update(ordinal, value) + override def setNullAt(ordinal: Int): Unit = array.setNullAt(ordinal) + override def setBoolean(ordinal: Int, value: Boolean): Unit = array.setBoolean(ordinal, value) + override def setByte(ordinal: Int, value: Byte): Unit = array.setByte(ordinal, value) + override def setShort(ordinal: Int, value: Short): Unit = array.setShort(ordinal, value) + override def setInt(ordinal: Int, value: Int): Unit = array.setInt(ordinal, value) + override def setLong(ordinal: Int, value: Long): Unit = array.setLong(ordinal, value) + override def setDouble(ordinal: Int, value: Double): Unit = array.setDouble(ordinal, value) + override def setFloat(ordinal: Int, value: Float): Unit = array.setFloat(ordinal, value) + override def setDecimal(ordinal: Int, value: Decimal): Unit = array.update(ordinal, value) + } + +} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala new file mode 100644 index 0000000000000..5d9af92c5c077 --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/ProtobufSerializer.scala @@ -0,0 +1,267 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.protobuf + +import scala.collection.JavaConverters._ + +import com.google.protobuf.{Duration, DynamicMessage, Timestamp} +import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} +import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._ + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.SpecializedGetters +import org.apache.spark.sql.catalyst.util.{DateTimeUtils, IntervalUtils} +import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE +import org.apache.spark.sql.protobuf.utils.ProtobufUtils +import org.apache.spark.sql.protobuf.utils.ProtobufUtils.{toFieldStr, ProtoMatchedField} +import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.types._ + +/** + * A serializer to serialize data in catalyst format to data in Protobuf format. + */ +private[sql] class ProtobufSerializer( + rootCatalystType: DataType, + rootDescriptor: Descriptor, + nullable: Boolean) + extends Logging { + + def serialize(catalystData: Any): Any = { + converter.apply(catalystData) + } + + private val converter: Any => Any = { + val baseConverter = + try { + rootCatalystType match { + case st: StructType => + newStructConverter(st, rootDescriptor, Nil, Nil).asInstanceOf[Any => Any] + } + } catch { + case ise: IncompatibleSchemaException => + throw new IncompatibleSchemaException( + s"Cannot convert SQL type ${rootCatalystType.sql} to Protobuf type " + + s"${rootDescriptor.getName}.", + ise) + } + if (nullable) { (data: Any) => + if (data == null) { + null + } else { + baseConverter.apply(data) + } + } else { + baseConverter + } + } + + private type Converter = (SpecializedGetters, Int) => Any + + private def newConverter( + catalystType: DataType, + fieldDescriptor: FieldDescriptor, + catalystPath: Seq[String], + protoPath: Seq[String]): Converter = { + val errorPrefix = s"Cannot convert SQL ${toFieldStr(catalystPath)} " + + s"to Protobuf ${toFieldStr(protoPath)} because " + (catalystType, fieldDescriptor.getJavaType) match { + case (NullType, _) => + (getter, ordinal) => null + case (BooleanType, BOOLEAN) => + (getter, ordinal) => getter.getBoolean(ordinal) + case (ByteType, INT) => + (getter, ordinal) => getter.getByte(ordinal).toInt + case (ShortType, INT) => + (getter, ordinal) => getter.getShort(ordinal).toInt + case (IntegerType, INT) => + (getter, ordinal) => { + getter.getInt(ordinal) + } + case (LongType, LONG) => + (getter, ordinal) => getter.getLong(ordinal) + case (FloatType, FLOAT) => + (getter, ordinal) => getter.getFloat(ordinal) + case (DoubleType, DOUBLE) => + (getter, ordinal) => getter.getDouble(ordinal) + case (StringType, ENUM) => + val enumSymbols: Set[String] = + fieldDescriptor.getEnumType.getValues.asScala.map(e => e.toString).toSet + (getter, ordinal) => + val data = getter.getUTF8String(ordinal).toString + if (!enumSymbols.contains(data)) { + throw new IncompatibleSchemaException( + errorPrefix + + s""""$data" cannot be written since it's not defined in enum 
""" + + enumSymbols.mkString("\"", "\", \"", "\"")) + } + fieldDescriptor.getEnumType.findValueByName(data) + case (StringType, STRING) => + (getter, ordinal) => { + String.valueOf(getter.getUTF8String(ordinal)) + } + + case (BinaryType, BYTE_STRING) => + (getter, ordinal) => getter.getBinary(ordinal) + + case (DateType, INT) => + (getter, ordinal) => getter.getInt(ordinal) + + case (TimestampType, MESSAGE) => + (getter, ordinal) => + val millis = DateTimeUtils.microsToMillis(getter.getLong(ordinal)) + Timestamp.newBuilder() + .setSeconds((millis / 1000)) + .setNanos(((millis % 1000) * 1000000).toInt) + .build() + + case (ArrayType(et, containsNull), _) => + val elementConverter = + newConverter(et, fieldDescriptor, catalystPath :+ "element", protoPath :+ "element") + (getter, ordinal) => { + val arrayData = getter.getArray(ordinal) + val len = arrayData.numElements() + val result = new Array[Any](len) + var i = 0 + while (i < len) { + if (containsNull && arrayData.isNullAt(i)) { + result(i) = null + } else { + result(i) = elementConverter(arrayData, i) + } + i += 1 + } + // Protobuf writer is expecting a Java Collection, so we convert it into + // `ArrayList` backed by the specified array without data copying. + java.util.Arrays.asList(result: _*) + } + + case (st: StructType, MESSAGE) => + val structConverter = + newStructConverter(st, fieldDescriptor.getMessageType, catalystPath, protoPath) + val numFields = st.length + (getter, ordinal) => structConverter(getter.getStruct(ordinal, numFields)) + + case (MapType(kt, vt, valueContainsNull), MESSAGE) => + var keyField: FieldDescriptor = null + var valueField: FieldDescriptor = null + fieldDescriptor.getMessageType.getFields.asScala.map { field => + field.getName match { + case "key" => + keyField = field + case "value" => + valueField = field + } + } + + val keyConverter = newConverter(kt, keyField, catalystPath :+ "key", protoPath :+ "key") + val valueConverter = + newConverter(vt, valueField, catalystPath :+ "value", protoPath :+ "value") + + (getter, ordinal) => + val mapData = getter.getMap(ordinal) + val len = mapData.numElements() + val list = new java.util.ArrayList[DynamicMessage]() + val keyArray = mapData.keyArray() + val valueArray = mapData.valueArray() + var i = 0 + while (i < len) { + val result = DynamicMessage.newBuilder(fieldDescriptor.getMessageType) + if (valueContainsNull && valueArray.isNullAt(i)) { + result.setField(keyField, keyConverter(keyArray, i)) + result.setField(valueField, valueField.getDefaultValue) + } else { + result.setField(keyField, keyConverter(keyArray, i)) + result.setField(valueField, valueConverter(valueArray, i)) + } + list.add(result.build()) + i += 1 + } + list + + case (DayTimeIntervalType(startField, endField), MESSAGE) => + (getter, ordinal) => + val dayTimeIntervalString = + IntervalUtils.toDayTimeIntervalString(getter.getLong(ordinal) + , ANSI_STYLE, startField, endField) + val calendarInterval = IntervalUtils.fromIntervalString(dayTimeIntervalString) + + val millis = DateTimeUtils.microsToMillis(calendarInterval.microseconds) + val duration = Duration.newBuilder() + .setSeconds((millis / 1000)) + .setNanos(((millis % 1000) * 1000000).toInt) + + if (duration.getSeconds < 0 && duration.getNanos > 0) { + duration.setSeconds(duration.getSeconds + 1) + duration.setNanos(duration.getNanos - 1000000000) + } else if (duration.getSeconds > 0 && duration.getNanos < 0) { + duration.setSeconds(duration.getSeconds - 1) + duration.setNanos(duration.getNanos + 1000000000) + } + duration.build() + + case 
_ => + throw new IncompatibleSchemaException( + errorPrefix + + s"schema is incompatible (sqlType = ${catalystType.sql}, " + + s"protoType = ${fieldDescriptor.getJavaType})") + } + } + + private def newStructConverter( + catalystStruct: StructType, + descriptor: Descriptor, + catalystPath: Seq[String], + protoPath: Seq[String]): InternalRow => DynamicMessage = { + + val protoSchemaHelper = + new ProtobufUtils.ProtoSchemaHelper(descriptor, catalystStruct, protoPath, catalystPath) + + protoSchemaHelper.validateNoExtraCatalystFields(ignoreNullable = false) + protoSchemaHelper.validateNoExtraRequiredProtoFields() + + val (protoIndices, fieldConverters: Array[Converter]) = protoSchemaHelper.matchedFields + .map { case ProtoMatchedField(catalystField, _, protoField) => + val converter = newConverter( + catalystField.dataType, + protoField, + catalystPath :+ catalystField.name, + protoPath :+ protoField.getName) + (protoField, converter) + } + .toArray + .unzip + + val numFields = catalystStruct.length + row: InternalRow => + val result = DynamicMessage.newBuilder(descriptor) + var i = 0 + while (i < numFields) { + if (row.isNullAt(i)) { + if (!protoIndices(i).isRepeated() && + protoIndices(i).getJavaType() != FieldDescriptor.JavaType.MESSAGE && + protoIndices(i).isRequired()) { + result.setField(protoIndices(i), protoIndices(i).getDefaultValue()) + } + } else { + result.setField(protoIndices(i), fieldConverters(i).apply(row, i)) + } + i += 1 + } + result.build() + } +} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala new file mode 100644 index 0000000000000..283d1ca8c412c --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/functions.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.protobuf + +import scala.collection.JavaConverters._ + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.Column + +// scalastyle:off: object.name +object functions { +// scalastyle:on: object.name + + /** + * Converts a binary column of Protobuf format into its corresponding catalyst value. The + * specified schema must match actual schema of the read data, otherwise the behavior is + * undefined: it may fail or return arbitrary result. To deserialize the data with a compatible + * and evolved schema, the expected Protobuf schema can be set via the option protoSchema. + * + * @param data + * the binary column. + * @param descFilePath + * the protobuf descriptor in Message GeneratedMessageV3 format. + * @param messageName + * the protobuf message name to look for in descriptorFile. 
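+   * @param options
+   *   options to control how the Protobuf record is parsed; currently only the "mode" option
+   *   ("PERMISSIVE" or "FAILFAST") is recognized, see `ProtobufOptions`.
+   *
+   *   An illustrative call (the dataframe, column name and descriptor path are placeholders):
+   *   {{{
+   *     val opts = new java.util.HashMap[String, String]()
+   *     opts.put("mode", "PERMISSIVE")
+   *     df.select(from_protobuf($"payload", "/tmp/events.desc", "Event", opts) as "event")
+   *   }}}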
+ * @since 3.4.0 + */ + @Experimental + def from_protobuf( + data: Column, + descFilePath: String, + messageName: String, + options: java.util.Map[String, String]): Column = { + new Column( + ProtobufDataToCatalyst(data.expr, descFilePath, messageName, options.asScala.toMap)) + } + + /** + * Converts a binary column of Protobuf format into its corresponding catalyst value. The + * specified schema must match actual schema of the read data, otherwise the behavior is + * undefined: it may fail or return arbitrary result. To deserialize the data with a compatible + * and evolved schema, the expected Protobuf schema can be set via the option protoSchema. + * + * @param data + * the binary column. + * @param descFilePath + * the protobuf descriptor in Message GeneratedMessageV3 format. + * @param messageName + * the protobuf MessageName to look for in descriptorFile. + * @since 3.4.0 + */ + @Experimental + def from_protobuf(data: Column, descFilePath: String, messageName: String): Column = { + new Column(ProtobufDataToCatalyst(data.expr, descFilePath, messageName, Map.empty)) + } + + /** + * Converts a column into binary of protobuf format. + * + * @param data + * the data column. + * @param descFilePath + * the protobuf descriptor in Message GeneratedMessageV3 format. + * @param messageName + * the protobuf MessageName to look for in descriptorFile. + * @since 3.4.0 + */ + @Experimental + def to_protobuf(data: Column, descFilePath: String, messageName: String): Column = { + new Column(CatalystDataToProtobuf(data.expr, descFilePath, messageName)) + } +} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/package.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/package.scala new file mode 100644 index 0000000000000..82cdc6b9c5816 --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/package.scala @@ -0,0 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql + +package object protobuf { + protected[protobuf] object ScalaReflectionLock +} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala new file mode 100644 index 0000000000000..1cece0d7966e5 --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufOptions.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.protobuf.utils + +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.catalyst.FileSourceOptions +import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, FailFastMode, ParseMode} + +/** + * Options for Protobuf Reader and Writer stored in case insensitive manner. + */ +private[sql] class ProtobufOptions( + @transient val parameters: CaseInsensitiveMap[String], + @transient val conf: Configuration) + extends FileSourceOptions(parameters) + with Logging { + + def this(parameters: Map[String, String], conf: Configuration) = { + this(CaseInsensitiveMap(parameters), conf) + } + + val parseMode: ParseMode = + parameters.get("mode").map(ParseMode.fromString).getOrElse(FailFastMode) +} + +private[sql] object ProtobufOptions { + def apply(parameters: Map[String, String]): ProtobufOptions = { + val hadoopConf = SparkSession.getActiveSession + .map(_.sessionState.newHadoopConf()) + .getOrElse(new Configuration()) + new ProtobufOptions(CaseInsensitiveMap(parameters), hadoopConf) + } +} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala new file mode 100644 index 0000000000000..5ad043142a2d2 --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/ProtobufUtils.scala @@ -0,0 +1,196 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.protobuf.utils + +import java.io.{BufferedInputStream, FileInputStream, IOException} +import java.util.Locale + +import scala.collection.JavaConverters._ + +import com.google.protobuf.{DescriptorProtos, Descriptors, InvalidProtocolBufferException} +import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.types._ + +private[sql] object ProtobufUtils extends Logging { + + /** Wrapper for a pair of matched fields, one Catalyst and one corresponding Protobuf field. */ + private[sql] case class ProtoMatchedField( + catalystField: StructField, + catalystPosition: Int, + fieldDescriptor: FieldDescriptor) + + /** + * Helper class to perform field lookup/matching on Protobuf schemas. + * + * This will match `descriptor` against `catalystSchema`, attempting to find a matching field in + * the Protobuf descriptor for each field in the Catalyst schema and vice-versa, respecting + * settings for case sensitivity. The match results can be accessed using the getter methods. + * + * @param descriptor + * The descriptor in which to search for fields. Must be of type Descriptor. + * @param catalystSchema + * The Catalyst schema to use for matching. + * @param protoPath + * The seq of parent field names leading to `protoSchema`. + * @param catalystPath + * The seq of parent field names leading to `catalystSchema`. + */ + class ProtoSchemaHelper( + descriptor: Descriptor, + catalystSchema: StructType, + protoPath: Seq[String], + catalystPath: Seq[String]) { + if (descriptor.getName == null) { + throw new IncompatibleSchemaException( + s"Attempting to treat ${descriptor.getName} as a RECORD, " + + s"but it was: ${descriptor.getContainingType}") + } + + private[this] val protoFieldArray = descriptor.getFields.asScala.toArray + private[this] val fieldMap = descriptor.getFields.asScala + .groupBy(_.getName.toLowerCase(Locale.ROOT)) + .mapValues(_.toSeq) // toSeq needed for scala 2.13 + + /** The fields which have matching equivalents in both Protobuf and Catalyst schemas. */ + val matchedFields: Seq[ProtoMatchedField] = catalystSchema.zipWithIndex.flatMap { + case (sqlField, sqlPos) => + getFieldByName(sqlField.name).map(ProtoMatchedField(sqlField, sqlPos, _)) + } + + /** + * Validate that there are no Catalyst fields which don't have a matching Protobuf field, + * throwing [[IncompatibleSchemaException]] if such extra fields are found. If + * `ignoreNullable` is false, consider nullable Catalyst fields to be eligible to be an extra + * field; otherwise, ignore nullable Catalyst fields when checking for extras. + */ + def validateNoExtraCatalystFields(ignoreNullable: Boolean): Unit = + catalystSchema.fields.foreach { sqlField => + if (getFieldByName(sqlField.name).isEmpty && + (!ignoreNullable || !sqlField.nullable)) { + throw new IncompatibleSchemaException( + s"Cannot find ${toFieldStr(catalystPath :+ sqlField.name)} in Protobuf schema") + } + } + + /** + * Validate that there are no Protobuf fields which don't have a matching Catalyst field, + * throwing [[IncompatibleSchemaException]] if such extra fields are found. Only required + * (non-nullable) fields are checked; nullable fields are ignored. 
+ */ + def validateNoExtraRequiredProtoFields(): Unit = { + val extraFields = protoFieldArray.toSet -- matchedFields.map(_.fieldDescriptor) + extraFields.filterNot(isNullable).foreach { extraField => + throw new IncompatibleSchemaException( + s"Found ${toFieldStr(protoPath :+ extraField.getName())} in Protobuf schema " + + "but there is no match in the SQL schema") + } + } + + /** + * Extract a single field from the contained Protobuf schema which has the desired field name, + * performing the matching with proper case sensitivity according to SQLConf.resolver. + * + * @param name + * The name of the field to search for. + * @return + * `Some(match)` if a matching Protobuf field is found, otherwise `None`. + */ + private[protobuf] def getFieldByName(name: String): Option[FieldDescriptor] = { + + // get candidates, ignoring case of field name + val candidates = fieldMap.getOrElse(name.toLowerCase(Locale.ROOT), Seq.empty) + + // search candidates, taking into account case sensitivity settings + candidates.filter(f => SQLConf.get.resolver(f.getName(), name)) match { + case Seq(protoField) => Some(protoField) + case Seq() => None + case matches => + throw new IncompatibleSchemaException( + s"Searching for '$name' in " + + s"Protobuf schema at ${toFieldStr(protoPath)} gave ${matches.size} matches. " + + s"Candidates: " + matches.map(_.getName()).mkString("[", ", ", "]")) + } + } + } + + def buildDescriptor(descFilePath: String, messageName: String): Descriptor = { + val fileDescriptor: Descriptors.FileDescriptor = parseFileDescriptor(descFilePath) + var result: Descriptors.Descriptor = null; + + for (descriptor <- fileDescriptor.getMessageTypes.asScala) { + if (descriptor.getName().equals(messageName)) { + result = descriptor + } + } + + if (null == result) { + throw new RuntimeException("Unable to locate Message '" + messageName + "' in Descriptor"); + } + result + } + + def parseFileDescriptor(descFilePath: String): Descriptors.FileDescriptor = { + var fileDescriptorSet: DescriptorProtos.FileDescriptorSet = null + try { + val dscFile = new BufferedInputStream(new FileInputStream(descFilePath)) + fileDescriptorSet = DescriptorProtos.FileDescriptorSet.parseFrom(dscFile) + } catch { + case ex: InvalidProtocolBufferException => + // TODO move all the exceptions to core/src/main/resources/error/error-classes.json + throw new RuntimeException("Error parsing descriptor byte[] into Descriptor object", ex) + case ex: IOException => + throw new RuntimeException( + "Error reading Protobuf descriptor file at path: " + + descFilePath, + ex) + } + + val descriptorProto: DescriptorProtos.FileDescriptorProto = fileDescriptorSet.getFile(0) + try { + val fileDescriptor: Descriptors.FileDescriptor = Descriptors.FileDescriptor.buildFrom( + descriptorProto, + new Array[Descriptors.FileDescriptor](0)) + if (fileDescriptor.getMessageTypes().isEmpty()) { + throw new RuntimeException("No MessageTypes returned, " + fileDescriptor.getName()); + } + fileDescriptor + } catch { + case e: Descriptors.DescriptorValidationException => + throw new RuntimeException("Error constructing FileDescriptor", e) + } + } + + /** + * Convert a sequence of hierarchical field names (like `Seq(foo, bar)`) into a human-readable + * string representing the field, like "field 'foo.bar'". If `names` is empty, the string + * "top-level record" is returned. 
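+   * For example, `toFieldStr(Seq("foo", "bar"))` returns "field 'foo.bar'".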
+ */ + private[protobuf] def toFieldStr(names: Seq[String]): String = names match { + case Seq() => "top-level record" + case n => s"field '${n.mkString(".")}'" + } + + /** Return true if `fieldDescriptor` is optional. */ + private[protobuf] def isNullable(fieldDescriptor: FieldDescriptor): Boolean = + !fieldDescriptor.isOptional + +} diff --git a/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala new file mode 100644 index 0000000000000..e385b816abe70 --- /dev/null +++ b/connector/protobuf/src/main/scala/org/apache/spark/sql/protobuf/utils/SchemaConverters.scala @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.protobuf.utils + +import scala.collection.JavaConverters._ + +import com.google.protobuf.Descriptors.{Descriptor, FieldDescriptor} + +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.sql.protobuf.ScalaReflectionLock +import org.apache.spark.sql.types._ + +@DeveloperApi +object SchemaConverters { + + /** + * Internal wrapper for SQL data type and nullability. + * + * @since 3.4.0 + */ + case class SchemaType(dataType: DataType, nullable: Boolean) + + /** + * Converts an Protobuf schema to a corresponding Spark SQL schema. 
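+   * For example (illustrative), a message containing `int64 id = 1;` and
+   * `repeated string tags = 2;` maps to a `StructType` with `id: LongType` and
+   * `tags: ArrayType(StringType, containsNull = false)`, following the `JavaType`
+   * cases handled in `structFieldFor` below.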
+ * + * @since 3.4.0 + */ + def toSqlType(descriptor: Descriptor): SchemaType = { + toSqlTypeHelper(descriptor) + } + + def toSqlTypeHelper(descriptor: Descriptor): SchemaType = ScalaReflectionLock.synchronized { + SchemaType( + StructType(descriptor.getFields.asScala.flatMap(structFieldFor(_, Set.empty)).toSeq), + nullable = true) + } + + def structFieldFor( + fd: FieldDescriptor, + existingRecordNames: Set[String]): Option[StructField] = { + import com.google.protobuf.Descriptors.FieldDescriptor.JavaType._ + val dataType = fd.getJavaType match { + case INT => Some(IntegerType) + case LONG => Some(LongType) + case FLOAT => Some(FloatType) + case DOUBLE => Some(DoubleType) + case BOOLEAN => Some(BooleanType) + case STRING => Some(StringType) + case BYTE_STRING => Some(BinaryType) + case ENUM => Some(StringType) + case MESSAGE if fd.getMessageType.getName == "Duration" => + Some(DayTimeIntervalType.defaultConcreteType) + case MESSAGE if fd.getMessageType.getName == "Timestamp" => + Some(TimestampType) + case MESSAGE if fd.isRepeated && fd.getMessageType.getOptions.hasMapEntry => + var keyType: DataType = NullType + var valueType: DataType = NullType + fd.getMessageType.getFields.forEach { field => + field.getName match { + case "key" => + keyType = structFieldFor(field, existingRecordNames).get.dataType + case "value" => + valueType = structFieldFor(field, existingRecordNames).get.dataType + } + } + return Option( + StructField( + fd.getName, + MapType(keyType, valueType, valueContainsNull = false).defaultConcreteType, + nullable = false)) + case MESSAGE => + if (existingRecordNames.contains(fd.getFullName)) { + throw new IncompatibleSchemaException(s""" + |Found recursive reference in Protobuf schema, which can not be processed by Spark: + |${fd.toString()}""".stripMargin) + } + val newRecordNames = existingRecordNames + fd.getFullName + + Option( + fd.getMessageType.getFields.asScala + .flatMap(structFieldFor(_, newRecordNames.toSet)) + .toSeq) + .filter(_.nonEmpty) + .map(StructType.apply) + case _ => + throw new IncompatibleSchemaException( + s"Cannot convert Protobuf type" + + s" ${fd.getJavaType}") + } + dataType.map(dt => + StructField( + fd.getName, + if (fd.isRepeated) ArrayType(dt, containsNull = false) else dt, + nullable = !fd.isRequired && !fd.isRepeated)) + } + + private[protobuf] class IncompatibleSchemaException(msg: String, ex: Throwable = null) + extends Exception(msg, ex) +} diff --git a/connector/protobuf/src/test/resources/protobuf/catalyst_types.desc b/connector/protobuf/src/test/resources/protobuf/catalyst_types.desc new file mode 100644 index 0000000000000..59255b488a03d --- /dev/null +++ b/connector/protobuf/src/test/resources/protobuf/catalyst_types.desc @@ -0,0 +1,48 @@ + +‰ +Cconnector/protobuf/src/test/resources/protobuf/catalyst_types.protoorg.apache.spark.sql.protobuf") + +BooleanMsg + bool_type (RboolType"+ + +IntegerMsg + +int32_type (R int32Type", + DoubleMsg + double_type (R +doubleType") +FloatMsg + +float_type (R floatType") +BytesMsg + +bytes_type ( R bytesType", + StringMsg + string_type ( R +stringType". 
+Person +name ( Rname +age (Rage"n +Bad +col_0 ( Rcol0 +col_1 (Rcol1 +col_2 ( Rcol2 +col_3 (Rcol3 +col_4 (Rcol4"q +Actual +col_0 ( Rcol0 +col_1 (Rcol1 +col_2 (Rcol2 +col_3 (Rcol3 +col_4 (Rcol4" + oldConsumer +key ( Rkey"5 + newProducer +key ( Rkey +value (Rvalue"t + newConsumer +key ( Rkey +value (Rvalue= +actual ( 2%.org.apache.spark.sql.protobuf.ActualRactual" + oldProducer +key ( RkeyBB CatalystTypesbproto3 \ No newline at end of file diff --git a/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto b/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto new file mode 100644 index 0000000000000..54e6bc18df153 --- /dev/null +++ b/connector/protobuf/src/test/resources/protobuf/catalyst_types.proto @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/catalyst_types.proto +// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/catalyst_types.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/catalyst_types.proto + +syntax = "proto3"; + +package org.apache.spark.sql.protobuf; +option java_outer_classname = "CatalystTypes"; + +message BooleanMsg { + bool bool_type = 1; +} +message IntegerMsg { + int32 int32_type = 1; +} +message DoubleMsg { + double double_type = 1; +} +message FloatMsg { + float float_type = 1; +} +message BytesMsg { + bytes bytes_type = 1; +} +message StringMsg { + string string_type = 1; +} + +message Person { + string name = 1; + int32 age = 2; +} + +message Bad { + bytes col_0 = 1; + double col_1 = 2; + string col_2 = 3; + float col_3 = 4; + int64 col_4 = 5; +} + +message Actual { + string col_0 = 1; + int32 col_1 = 2; + float col_2 = 3; + bool col_3 = 4; + double col_4 = 5; +} + +message oldConsumer { + string key = 1; +} + +message newProducer { + string key = 1; + int32 value = 2; +} + +message newConsumer { + string key = 1; + int32 value = 2; + Actual actual = 3; +} + +message oldProducer { + string key = 1; +} \ No newline at end of file diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.desc b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc new file mode 100644 index 0000000000000..6e3a396727729 Binary files /dev/null and b/connector/protobuf/src/test/resources/protobuf/functions_suite.desc differ diff --git a/connector/protobuf/src/test/resources/protobuf/functions_suite.proto b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto new file mode 100644 index 0000000000000..f38c041b799ec --- /dev/null +++ b/connector/protobuf/src/test/resources/protobuf/functions_suite.proto @@ -0,0 +1,190 
@@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+// To compile and create test class:
+// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/functions_suite.proto
+// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/functions_suite.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/functions_suite.proto
+
+syntax = "proto3";
+
+package org.apache.spark.sql.protobuf;
+
+option java_outer_classname = "SimpleMessageProtos";
+
+message SimpleMessageJavaTypes {
+  int64 id = 1;
+  string string_value = 2;
+  int32 int32_value = 3;
+  int64 int64_value = 4;
+  double double_value = 5;
+  float float_value = 6;
+  bool bool_value = 7;
+  bytes bytes_value = 8;
+}
+
+message SimpleMessage {
+  int64 id = 1;
+  string string_value = 2;
+  int32 int32_value = 3;
+  uint32 uint32_value = 4;
+  sint32 sint32_value = 5;
+  fixed32 fixed32_value = 6;
+  sfixed32 sfixed32_value = 7;
+  int64 int64_value = 8;
+  uint64 uint64_value = 9;
+  sint64 sint64_value = 10;
+  fixed64 fixed64_value = 11;
+  sfixed64 sfixed64_value = 12;
+  double double_value = 13;
+  float float_value = 14;
+  bool bool_value = 15;
+  bytes bytes_value = 16;
+}
+
+message SimpleMessageRepeated {
+  string key = 1;
+  string value = 2;
+  enum NestedEnum {
+    ESTED_NOTHING = 0;
+    NESTED_FIRST = 1;
+    NESTED_SECOND = 2;
+  }
+  repeated string rstring_value = 3;
+  repeated int32 rint32_value = 4;
+  repeated bool rbool_value = 5;
+  repeated int64 rint64_value = 6;
+  repeated float rfloat_value = 7;
+  repeated double rdouble_value = 8;
+  repeated bytes rbytes_value = 9;
+  repeated NestedEnum rnested_enum = 10;
+}
+
+message BasicMessage {
+  int64 id = 1;
+  string string_value = 2;
+  int32 int32_value = 3;
+  int64 int64_value = 4;
+  double double_value = 5;
+  float float_value = 6;
+  bool bool_value = 7;
+  bytes bytes_value = 8;
+}
+
+message RepeatedMessage {
+  repeated BasicMessage basic_message = 1;
+}
+
+message SimpleMessageMap {
+  string key = 1;
+  string value = 2;
+  map<string, string> string_mapdata = 3;
+  map<int32, int32> int32_mapdata = 4;
+  map<uint32, uint32> uint32_mapdata = 5;
+  map<sint32, sint32> sint32_mapdata = 6;
+  map<fixed32, fixed32> float32_mapdata = 7;
+  map<sfixed32, sfixed32> sfixed32_mapdata = 8;
+  map<int64, int64> int64_mapdata = 9;
+  map<uint64, uint64> uint64_mapdata = 10;
+  map<sint64, sint64> sint64_mapdata = 11;
+  map<fixed64, fixed64> fixed64_mapdata = 12;
+  map<sfixed64, sfixed64> sfixed64_mapdata = 13;
+  map<string, double> double_mapdata = 14;
+  map<string, float> float_mapdata = 15;
+  map<bool, bool> bool_mapdata = 16;
+  map<string, bytes> bytes_mapdata = 17;
+}
+
+message BasicEnumMessage {
+  enum BasicEnum {
+    NOTHING = 0;
+    FIRST = 1;
+    SECOND = 2;
+  }
+}
+
+message SimpleMessageEnum {
+  string key = 1;
+  string value = 2;
+  enum NestedEnum {
+    ESTED_NOTHING = 0;
+    NESTED_FIRST = 1;
+    NESTED_SECOND = 2;
+  }
+
BasicEnumMessage.BasicEnum basic_enum = 3; + NestedEnum nested_enum = 4; +} + + +message OtherExample { + string other = 1; +} + +message IncludedExample { + string included = 1; + OtherExample other = 2; +} + +message MultipleExample { + IncludedExample included_example = 1; +} + +message recursiveA { + string keyA = 1; + recursiveB messageB = 2; +} + +message recursiveB { + string keyB = 1; + recursiveA messageA = 2; +} + +message recursiveC { + string keyC = 1; + recursiveD messageD = 2; +} + +message recursiveD { + string keyD = 1; + repeated recursiveC messageC = 2; +} + +message requiredMsg { + string key = 1; + int32 col_1 = 2; + string col_2 = 3; + int32 col_3 = 4; +} + +// https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/timestamp.proto +message Timestamp { + int64 seconds = 1; + int32 nanos = 2; +} + +message timeStampMsg { + string key = 1; + Timestamp stmp = 2; +} +// https://github.com/protocolbuffers/protobuf/blob/main/src/google/protobuf/duration.proto +message Duration { + int64 seconds = 1; + int32 nanos = 2; +} + +message durationMsg { + string key = 1; + Duration duration = 2; +} \ No newline at end of file diff --git a/connector/protobuf/src/test/resources/protobuf/serde_suite.desc b/connector/protobuf/src/test/resources/protobuf/serde_suite.desc new file mode 100644 index 0000000000000..3d1847eecc5c3 --- /dev/null +++ b/connector/protobuf/src/test/resources/protobuf/serde_suite.desc @@ -0,0 +1,27 @@ + +² +Fconnector/protobuf/src/test/resources/protobuf/proto_serde_suite.protoorg.apache.spark.sql.protobuf"D + BasicMessage4 +foo ( 2".org.apache.spark.sql.protobuf.FooRfoo" +Foo +bar (Rbar"' +MissMatchTypeInRoot +foo (Rfoo"T +FieldMissingInProto= +foo ( 2+.org.apache.spark.sql.protobuf.MissingFieldRfoo"& + MissingField +barFoo (RbarFoo"\ +MissMatchTypeInDeepNested? +top ( 2-.org.apache.spark.sql.protobuf.TypeMissNestedRtop"K +TypeMissNested9 +foo ( 2'.org.apache.spark.sql.protobuf.TypeMissRfoo" +TypeMiss +bar (Rbar"_ +FieldMissingInSQLRoot4 +foo ( 2".org.apache.spark.sql.protobuf.FooRfoo +boo (Rboo"O +FieldMissingInSQLNested4 +foo ( 2".org.apache.spark.sql.protobuf.BazRfoo") +Baz +bar (Rbar +baz (RbazBBSimpleMessageProtosbproto3 \ No newline at end of file diff --git a/connector/protobuf/src/test/resources/protobuf/serde_suite.proto b/connector/protobuf/src/test/resources/protobuf/serde_suite.proto new file mode 100644 index 0000000000000..1e3065259aa02 --- /dev/null +++ b/connector/protobuf/src/test/resources/protobuf/serde_suite.proto @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +// To compile and create test class: +// protoc --java_out=connector/protobuf/src/test/resources/protobuf/ connector/protobuf/src/test/resources/protobuf/serde_suite.proto +// protoc --descriptor_set_out=connector/protobuf/src/test/resources/protobuf/serde_suite.desc --java_out=connector/protobuf/src/test/resources/protobuf/org/apache/spark/sql/protobuf/ connector/protobuf/src/test/resources/protobuf/serde_suite.proto + +syntax = "proto3"; + +package org.apache.spark.sql.protobuf; +option java_outer_classname = "SimpleMessageProtos"; + +/* Clean Message*/ +message BasicMessage { + Foo foo = 1; +} + +message Foo { + int32 bar = 1; +} + +/* Field Type missMatch in root Message*/ +message MissMatchTypeInRoot { + int64 foo = 1; +} + +/* Field bar missing from protobuf and Available in SQL*/ +message FieldMissingInProto { + MissingField foo = 1; +} + +message MissingField { + int64 barFoo = 1; +} + +/* Deep-nested field bar type missMatch Message*/ +message MissMatchTypeInDeepNested { + TypeMissNested top = 1; +} + +message TypeMissNested { + TypeMiss foo = 1; +} + +message TypeMiss { + int64 bar = 1; +} + +/* Field boo missing from SQL root, but available in Protobuf root*/ +message FieldMissingInSQLRoot { + Foo foo = 1; + int32 boo = 2; +} + +/* Field baz missing from SQL nested and available in Protobuf nested*/ +message FieldMissingInSQLNested { + Baz foo = 1; +} + +message Baz { + int32 bar = 1; + int32 baz = 2; +} \ No newline at end of file diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala new file mode 100644 index 0000000000000..b730ebb4fea80 --- /dev/null +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufCatalystDataConversionSuite.scala @@ -0,0 +1,212 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.protobuf + +import com.google.protobuf.{ByteString, DynamicMessage, Message} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{RandomDataGenerator, Row} +import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, NoopFilters, OrderedFilters, StructFilters} +import org.apache.spark.sql.catalyst.expressions.{ExpressionEvalHelper, GenericInternalRow, Literal} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, MapData} +import org.apache.spark.sql.protobuf.utils.{ProtobufUtils, SchemaConverters} +import org.apache.spark.sql.sources.{EqualTo, Not} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +class ProtobufCatalystDataConversionSuite + extends SparkFunSuite + with SharedSparkSession + with ExpressionEvalHelper { + + private def checkResult( + data: Literal, + descFilePath: String, + messageName: String, + expected: Any): Unit = { + checkEvaluation( + ProtobufDataToCatalyst( + CatalystDataToProtobuf(data, descFilePath, messageName), + descFilePath, + messageName, + Map.empty), + prepareExpectedResult(expected)) + } + + protected def checkUnsupportedRead( + data: Literal, + descFilePath: String, + actualSchema: String, + badSchema: String): Unit = { + + val binary = CatalystDataToProtobuf(data, descFilePath, actualSchema) + + intercept[Exception] { + ProtobufDataToCatalyst(binary, descFilePath, badSchema, Map("mode" -> "FAILFAST")).eval() + } + + val expected = { + val expectedSchema = ProtobufUtils.buildDescriptor(descFilePath, badSchema) + SchemaConverters.toSqlType(expectedSchema).dataType match { + case st: StructType => + Row.fromSeq((0 until st.length).map { _ => + null + }) + case _ => null + } + } + + checkEvaluation( + ProtobufDataToCatalyst(binary, descFilePath, badSchema, Map("mode" -> "PERMISSIVE")), + expected) + } + + protected def prepareExpectedResult(expected: Any): Any = expected match { + // Spark byte and short both map to Protobuf int + case b: Byte => b.toInt + case s: Short => s.toInt + case row: GenericInternalRow => InternalRow.fromSeq(row.values.map(prepareExpectedResult)) + case array: GenericArrayData => new GenericArrayData(array.array.map(prepareExpectedResult)) + case map: MapData => + val keys = new GenericArrayData( + map.keyArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult)) + val values = new GenericArrayData( + map.valueArray().asInstanceOf[GenericArrayData].array.map(prepareExpectedResult)) + new ArrayBasedMapData(keys, values) + case other => other + } + + private val testingTypes = Seq( + StructType(StructField("int32_type", IntegerType, nullable = true) :: Nil), + StructType(StructField("double_type", DoubleType, nullable = true) :: Nil), + StructType(StructField("float_type", FloatType, nullable = true) :: Nil), + StructType(StructField("bytes_type", BinaryType, nullable = true) :: Nil), + StructType(StructField("string_type", StringType, nullable = true) :: Nil)) + + private val catalystTypesToProtoMessages: Map[DataType, String] = Map( + IntegerType -> "IntegerMsg", + DoubleType -> "DoubleMsg", + FloatType -> "FloatMsg", + BinaryType -> "BytesMsg", + StringType -> "StringMsg") + + testingTypes.foreach { dt => + val seed = 1 + scala.util.Random.nextInt((1024 - 1) + 1) + val filePath = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + test(s"single $dt with seed $seed") { + val rand = new scala.util.Random(seed) + val data 
= RandomDataGenerator.forType(dt, rand = rand).get.apply() + val converter = CatalystTypeConverters.createToCatalystConverter(dt) + val input = Literal.create(converter(data), dt) + + checkResult( + input, + filePath, + catalystTypesToProtoMessages(dt.fields(0).dataType), + input.eval()) + } + } + + private def checkDeserialization( + descFilePath: String, + messageName: String, + data: Message, + expected: Option[Any], + filters: StructFilters = new NoopFilters): Unit = { + + val descriptor = ProtobufUtils.buildDescriptor(descFilePath, messageName) + val dataType = SchemaConverters.toSqlType(descriptor).dataType + + val deserializer = new ProtobufDeserializer(descriptor, dataType, filters) + + val dynMsg = DynamicMessage.parseFrom(descriptor, data.toByteArray) + val deserialized = deserializer.deserialize(dynMsg) + expected match { + case None => assert(deserialized.isEmpty) + case Some(d) => + assert(checkResult(d, deserialized.get, dataType, exprNullable = false)) + } + } + + test("Handle unsupported input of message type") { + val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + val actualSchema = StructType( + Seq( + StructField("col_0", StringType, nullable = false), + StructField("col_1", IntegerType, nullable = false), + StructField("col_2", FloatType, nullable = false), + StructField("col_3", BooleanType, nullable = false), + StructField("col_4", DoubleType, nullable = false))) + + val seed = scala.util.Random.nextLong() + withClue(s"create random record with seed $seed") { + val data = RandomDataGenerator.randomRow(new scala.util.Random(seed), actualSchema) + val converter = CatalystTypeConverters.createToCatalystConverter(actualSchema) + val input = Literal.create(converter(data), actualSchema) + checkUnsupportedRead(input, testFileDesc, "Actual", "Bad") + } + } + + test("filter push-down to Protobuf deserializer") { + + val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + val sqlSchema = new StructType() + .add("name", "string") + .add("age", "int") + + val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "Person") + val dynamicMessage = DynamicMessage + .newBuilder(descriptor) + .setField(descriptor.findFieldByName("name"), "Maxim") + .setField(descriptor.findFieldByName("age"), 39) + .build() + + val expectedRow = Some(InternalRow(UTF8String.fromString("Maxim"), 39)) + checkDeserialization(testFileDesc, "Person", dynamicMessage, expectedRow) + checkDeserialization( + testFileDesc, + "Person", + dynamicMessage, + expectedRow, + new OrderedFilters(Seq(EqualTo("age", 39)), sqlSchema)) + + checkDeserialization( + testFileDesc, + "Person", + dynamicMessage, + None, + new OrderedFilters(Seq(Not(EqualTo("name", "Maxim"))), sqlSchema)) + } + + test("ProtobufDeserializer with binary type") { + + val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + val bb = java.nio.ByteBuffer.wrap(Array[Byte](97, 48, 53)) + + val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "BytesMsg") + + val dynamicMessage = DynamicMessage + .newBuilder(descriptor) + .setField(descriptor.findFieldByName("bytes_type"), ByteString.copyFrom(bb)) + .build() + + val expected = InternalRow(Array[Byte](97, 48, 53)) + checkDeserialization(testFileDesc, "BytesMsg", dynamicMessage, Some(expected)) + } +} diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala new file 
mode 100644 index 0000000000000..4e9bc1c1c287a --- /dev/null +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufFunctionsSuite.scala @@ -0,0 +1,615 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.protobuf + +import java.sql.Timestamp +import java.time.Duration + +import scala.collection.JavaConverters._ + +import com.google.protobuf.{ByteString, DynamicMessage} + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.functions.{lit, struct} +import org.apache.spark.sql.protobuf.utils.ProtobufUtils +import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{DayTimeIntervalType, IntegerType, StringType, StructField, StructType, TimestampType} + +class ProtobufFunctionsSuite extends QueryTest with SharedSparkSession with Serializable { + + import testImplicits._ + + val testFileDesc = testFile("protobuf/functions_suite.desc").replace("file:/", "/") + + test("roundtrip in to_protobuf and from_protobuf - struct") { + val df = spark + .range(1, 10) + .select(struct( + $"id", + $"id".cast("string").as("string_value"), + $"id".cast("int").as("int32_value"), + $"id".cast("int").as("uint32_value"), + $"id".cast("int").as("sint32_value"), + $"id".cast("int").as("fixed32_value"), + $"id".cast("int").as("sfixed32_value"), + $"id".cast("long").as("int64_value"), + $"id".cast("long").as("uint64_value"), + $"id".cast("long").as("sint64_value"), + $"id".cast("long").as("fixed64_value"), + $"id".cast("long").as("sfixed64_value"), + $"id".cast("double").as("double_value"), + lit(1202.00).cast(org.apache.spark.sql.types.FloatType).as("float_value"), + lit(true).as("bool_value"), + lit("0".getBytes).as("bytes_value")).as("SimpleMessage")) + val protoStructDF = df.select( + functions.to_protobuf($"SimpleMessage", testFileDesc, "SimpleMessage").as("proto")) + val actualDf = protoStructDF.select( + functions.from_protobuf($"proto", testFileDesc, "SimpleMessage").as("proto.*")) + checkAnswer(actualDf, df) + } + + test("roundtrip in from_protobuf and to_protobuf - Repeated") { + val descriptor = ProtobufUtils.buildDescriptor(testFileDesc, "SimpleMessageRepeated") + + val dynamicMessage = DynamicMessage + .newBuilder(descriptor) + .setField(descriptor.findFieldByName("key"), "key") + .setField(descriptor.findFieldByName("value"), "value") + .addRepeatedField(descriptor.findFieldByName("rbool_value"), false) + .addRepeatedField(descriptor.findFieldByName("rbool_value"), true) + .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 1092092.654d) + .addRepeatedField(descriptor.findFieldByName("rdouble_value"), 1092093.654d) + 
.addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10903.0f) + .addRepeatedField(descriptor.findFieldByName("rfloat_value"), 10902.0f) + .addRepeatedField( + descriptor.findFieldByName("rnested_enum"), + descriptor.findEnumTypeByName("NestedEnum").findValueByName("ESTED_NOTHING")) + .addRepeatedField( + descriptor.findFieldByName("rnested_enum"), + descriptor.findEnumTypeByName("NestedEnum").findValueByName("NESTED_FIRST")) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select( + functions.from_protobuf($"value", testFileDesc, "SimpleMessageRepeated").as("value_from")) + val toProtoDF = fromProtoDF.select( + functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageRepeated").as("value_to")) + val toFromProtoDF = toProtoDF.select( + functions + .from_protobuf($"value_to", testFileDesc, "SimpleMessageRepeated") + .as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_protobuf and to_protobuf - Repeated Message Once") { + val repeatedMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "RepeatedMessage") + val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage") + + val basicMessage = DynamicMessage + .newBuilder(basicMessageDesc) + .setField(basicMessageDesc.findFieldByName("id"), 1111L) + .setField(basicMessageDesc.findFieldByName("string_value"), "value") + .setField(basicMessageDesc.findFieldByName("int32_value"), 12345) + .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L) + .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0d) + .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f) + .setField(basicMessageDesc.findFieldByName("bool_value"), true) + .setField( + basicMessageDesc.findFieldByName("bytes_value"), + ByteString.copyFromUtf8("ProtobufDeserializer")) + .build() + + val dynamicMessage = DynamicMessage + .newBuilder(repeatedMessageDesc) + .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select( + functions.from_protobuf($"value", testFileDesc, "RepeatedMessage").as("value_from")) + val toProtoDF = fromProtoDF.select( + functions.to_protobuf($"value_from", testFileDesc, "RepeatedMessage").as("value_to")) + val toFromProtoDF = toProtoDF.select( + functions.from_protobuf($"value_to", testFileDesc, "RepeatedMessage").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_protobuf and to_protobuf - Repeated Message Twice") { + val repeatedMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "RepeatedMessage") + val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage") + + val basicMessage1 = DynamicMessage + .newBuilder(basicMessageDesc) + .setField(basicMessageDesc.findFieldByName("id"), 1111L) + .setField(basicMessageDesc.findFieldByName("string_value"), "value1") + .setField(basicMessageDesc.findFieldByName("int32_value"), 12345) + .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L) + .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0d) + .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f) + .setField(basicMessageDesc.findFieldByName("bool_value"), true) + .setField( + basicMessageDesc.findFieldByName("bytes_value"), + 
ByteString.copyFromUtf8("ProtobufDeserializer1")) + .build() + val basicMessage2 = DynamicMessage + .newBuilder(basicMessageDesc) + .setField(basicMessageDesc.findFieldByName("id"), 1112L) + .setField(basicMessageDesc.findFieldByName("string_value"), "value2") + .setField(basicMessageDesc.findFieldByName("int32_value"), 12346) + .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L) + .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0d) + .setField(basicMessageDesc.findFieldByName("float_value"), 10903.0f) + .setField(basicMessageDesc.findFieldByName("bool_value"), false) + .setField( + basicMessageDesc.findFieldByName("bytes_value"), + ByteString.copyFromUtf8("ProtobufDeserializer2")) + .build() + + val dynamicMessage = DynamicMessage + .newBuilder(repeatedMessageDesc) + .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage1) + .addRepeatedField(repeatedMessageDesc.findFieldByName("basic_message"), basicMessage2) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select( + functions.from_protobuf($"value", testFileDesc, "RepeatedMessage").as("value_from")) + val toProtoDF = fromProtoDF.select( + functions.to_protobuf($"value_from", testFileDesc, "RepeatedMessage").as("value_to")) + val toFromProtoDF = toProtoDF.select( + functions.from_protobuf($"value_to", testFileDesc, "RepeatedMessage").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_protobuf and to_protobuf - Map") { + val messageMapDesc = ProtobufUtils.buildDescriptor(testFileDesc, "SimpleMessageMap") + + val mapStr1 = DynamicMessage + .newBuilder(messageMapDesc.findNestedTypeByName("StringMapdataEntry")) + .setField( + messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("key"), + "string_key") + .setField( + messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("value"), + "value1") + .build() + val mapStr2 = DynamicMessage + .newBuilder(messageMapDesc.findNestedTypeByName("StringMapdataEntry")) + .setField( + messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("key"), + "string_key") + .setField( + messageMapDesc.findNestedTypeByName("StringMapdataEntry").findFieldByName("value"), + "value2") + .build() + val mapInt64 = DynamicMessage + .newBuilder(messageMapDesc.findNestedTypeByName("Int64MapdataEntry")) + .setField( + messageMapDesc.findNestedTypeByName("Int64MapdataEntry").findFieldByName("key"), + 0x90000000000L) + .setField( + messageMapDesc.findNestedTypeByName("Int64MapdataEntry").findFieldByName("value"), + 0x90000000001L) + .build() + val mapInt32 = DynamicMessage + .newBuilder(messageMapDesc.findNestedTypeByName("Int32MapdataEntry")) + .setField( + messageMapDesc.findNestedTypeByName("Int32MapdataEntry").findFieldByName("key"), + 12345) + .setField( + messageMapDesc.findNestedTypeByName("Int32MapdataEntry").findFieldByName("value"), + 54321) + .build() + val mapFloat = DynamicMessage + .newBuilder(messageMapDesc.findNestedTypeByName("FloatMapdataEntry")) + .setField( + messageMapDesc.findNestedTypeByName("FloatMapdataEntry").findFieldByName("key"), + "float_key") + .setField( + messageMapDesc.findNestedTypeByName("FloatMapdataEntry").findFieldByName("value"), + 109202.234f) + .build() + val mapDouble = DynamicMessage + .newBuilder(messageMapDesc.findNestedTypeByName("DoubleMapdataEntry")) + .setField( + 
messageMapDesc.findNestedTypeByName("DoubleMapdataEntry").findFieldByName("key"), + "double_key") + .setField( + messageMapDesc.findNestedTypeByName("DoubleMapdataEntry").findFieldByName("value"), + 109202.12d) + .build() + val mapBool = DynamicMessage + .newBuilder(messageMapDesc.findNestedTypeByName("BoolMapdataEntry")) + .setField( + messageMapDesc.findNestedTypeByName("BoolMapdataEntry").findFieldByName("key"), + true) + .setField( + messageMapDesc.findNestedTypeByName("BoolMapdataEntry").findFieldByName("value"), + false) + .build() + + val dynamicMessage = DynamicMessage + .newBuilder(messageMapDesc) + .setField(messageMapDesc.findFieldByName("key"), "key") + .setField(messageMapDesc.findFieldByName("value"), "value") + .addRepeatedField(messageMapDesc.findFieldByName("string_mapdata"), mapStr1) + .addRepeatedField(messageMapDesc.findFieldByName("string_mapdata"), mapStr2) + .addRepeatedField(messageMapDesc.findFieldByName("int64_mapdata"), mapInt64) + .addRepeatedField(messageMapDesc.findFieldByName("int32_mapdata"), mapInt32) + .addRepeatedField(messageMapDesc.findFieldByName("float_mapdata"), mapFloat) + .addRepeatedField(messageMapDesc.findFieldByName("double_mapdata"), mapDouble) + .addRepeatedField(messageMapDesc.findFieldByName("bool_mapdata"), mapBool) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select( + functions.from_protobuf($"value", testFileDesc, "SimpleMessageMap").as("value_from")) + val toProtoDF = fromProtoDF.select( + functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageMap").as("value_to")) + val toFromProtoDF = toProtoDF.select( + functions.from_protobuf($"value_to", testFileDesc, "SimpleMessageMap").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_protobuf and to_protobuf - Enum") { + val messageEnumDesc = ProtobufUtils.buildDescriptor(testFileDesc, "SimpleMessageEnum") + val basicEnumDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicEnumMessage") + + val dynamicMessage = DynamicMessage + .newBuilder(messageEnumDesc) + .setField(messageEnumDesc.findFieldByName("key"), "key") + .setField(messageEnumDesc.findFieldByName("value"), "value") + .setField( + messageEnumDesc.findFieldByName("nested_enum"), + messageEnumDesc.findEnumTypeByName("NestedEnum").findValueByName("ESTED_NOTHING")) + .setField( + messageEnumDesc.findFieldByName("nested_enum"), + messageEnumDesc.findEnumTypeByName("NestedEnum").findValueByName("NESTED_FIRST")) + .setField( + messageEnumDesc.findFieldByName("basic_enum"), + basicEnumDesc.findEnumTypeByName("BasicEnum").findValueByName("FIRST")) + .setField( + messageEnumDesc.findFieldByName("basic_enum"), + basicEnumDesc.findEnumTypeByName("BasicEnum").findValueByName("NOTHING")) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select( + functions.from_protobuf($"value", testFileDesc, "SimpleMessageEnum").as("value_from")) + val toProtoDF = fromProtoDF.select( + functions.to_protobuf($"value_from", testFileDesc, "SimpleMessageEnum").as("value_to")) + val toFromProtoDF = toProtoDF.select( + functions.from_protobuf($"value_to", testFileDesc, "SimpleMessageEnum").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("roundtrip in from_protobuf and to_protobuf - Multiple Message") { + val messageMultiDesc = ProtobufUtils.buildDescriptor(testFileDesc, "MultipleExample") + val 
messageIncludeDesc = ProtobufUtils.buildDescriptor(testFileDesc, "IncludedExample") + val messageOtherDesc = ProtobufUtils.buildDescriptor(testFileDesc, "OtherExample") + + val otherMessage = DynamicMessage + .newBuilder(messageOtherDesc) + .setField(messageOtherDesc.findFieldByName("other"), "other value") + .build() + + val includeMessage = DynamicMessage + .newBuilder(messageIncludeDesc) + .setField(messageIncludeDesc.findFieldByName("included"), "included value") + .setField(messageIncludeDesc.findFieldByName("other"), otherMessage) + .build() + + val dynamicMessage = DynamicMessage + .newBuilder(messageMultiDesc) + .setField(messageMultiDesc.findFieldByName("included_example"), includeMessage) + .build() + + val df = Seq(dynamicMessage.toByteArray).toDF("value") + val fromProtoDF = df.select( + functions.from_protobuf($"value", testFileDesc, "MultipleExample").as("value_from")) + val toProtoDF = fromProtoDF.select( + functions.to_protobuf($"value_from", testFileDesc, "MultipleExample").as("value_to")) + val toFromProtoDF = toProtoDF.select( + functions.from_protobuf($"value_to", testFileDesc, "MultipleExample").as("value_to_from")) + checkAnswer(fromProtoDF.select($"value_from.*"), toFromProtoDF.select($"value_to_from.*")) + } + + test("Handle recursive fields in Protobuf schema, A->B->A") { + val schemaA = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveA") + val schemaB = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveB") + + val messageBForA = DynamicMessage + .newBuilder(schemaB) + .setField(schemaB.findFieldByName("keyB"), "key") + .build() + + val messageA = DynamicMessage + .newBuilder(schemaA) + .setField(schemaA.findFieldByName("keyA"), "key") + .setField(schemaA.findFieldByName("messageB"), messageBForA) + .build() + + val messageB = DynamicMessage + .newBuilder(schemaB) + .setField(schemaB.findFieldByName("keyB"), "key") + .setField(schemaB.findFieldByName("messageA"), messageA) + .build() + + val df = Seq(messageB.toByteArray).toDF("messageB") + + val e = intercept[IncompatibleSchemaException] { + df.select( + functions.from_protobuf($"messageB", testFileDesc, "recursiveB").as("messageFromProto")) + .show() + } + val expectedMessage = s""" + |Found recursive reference in Protobuf schema, which can not be processed by Spark: + |org.apache.spark.sql.protobuf.recursiveB.messageA""".stripMargin + assert(e.getMessage == expectedMessage) + } + + test("Handle recursive fields in Protobuf schema, C->D->Array(C)") { + val schemaC = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveC") + val schemaD = ProtobufUtils.buildDescriptor(testFileDesc, "recursiveD") + + val messageDForC = DynamicMessage + .newBuilder(schemaD) + .setField(schemaD.findFieldByName("keyD"), "key") + .build() + + val messageC = DynamicMessage + .newBuilder(schemaC) + .setField(schemaC.findFieldByName("keyC"), "key") + .setField(schemaC.findFieldByName("messageD"), messageDForC) + .build() + + val messageD = DynamicMessage + .newBuilder(schemaD) + .setField(schemaD.findFieldByName("keyD"), "key") + .addRepeatedField(schemaD.findFieldByName("messageC"), messageC) + .build() + + val df = Seq(messageD.toByteArray).toDF("messageD") + + val e = intercept[IncompatibleSchemaException] { + df.select( + functions.from_protobuf($"messageD", testFileDesc, "recursiveD").as("messageFromProto")) + .show() + } + val expectedMessage = + s""" + |Found recursive reference in Protobuf schema, which can not be processed by Spark: + |org.apache.spark.sql.protobuf.recursiveD.messageC""".stripMargin + 
assert(e.getMessage == expectedMessage) + } + + test("Handle extra fields : oldProducer -> newConsumer") { + val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + val oldProducer = ProtobufUtils.buildDescriptor(testFileDesc, "oldProducer") + val newConsumer = ProtobufUtils.buildDescriptor(testFileDesc, "newConsumer") + + val oldProducerMessage = DynamicMessage + .newBuilder(oldProducer) + .setField(oldProducer.findFieldByName("key"), "key") + .build() + + val df = Seq(oldProducerMessage.toByteArray).toDF("oldProducerData") + val fromProtoDf = df.select( + functions + .from_protobuf($"oldProducerData", testFileDesc, "newConsumer") + .as("fromProto")) + + val toProtoDf = fromProtoDf.select( + functions + .to_protobuf($"fromProto", testFileDesc, "newConsumer") + .as("toProto")) + + val toProtoDfToFromProtoDf = toProtoDf.select( + functions + .from_protobuf($"toProto", testFileDesc, "newConsumer") + .as("toProtoToFromProto")) + + val actualFieldNames = + toProtoDfToFromProtoDf.select("toProtoToFromProto.*").schema.fields.toSeq.map(f => f.name) + newConsumer.getFields.asScala.map { f => + { + assert(actualFieldNames.contains(f.getName)) + + } + } + assert( + toProtoDfToFromProtoDf.select("toProtoToFromProto.value").take(1).toSeq(0).get(0) == null) + assert( + toProtoDfToFromProtoDf.select("toProtoToFromProto.actual.*").take(1).toSeq(0).get(0) == null) + } + + test("Handle extra fields : newProducer -> oldConsumer") { + val testFileDesc = testFile("protobuf/catalyst_types.desc").replace("file:/", "/") + val newProducer = ProtobufUtils.buildDescriptor(testFileDesc, "newProducer") + val oldConsumer = ProtobufUtils.buildDescriptor(testFileDesc, "oldConsumer") + + val newProducerMessage = DynamicMessage + .newBuilder(newProducer) + .setField(newProducer.findFieldByName("key"), "key") + .setField(newProducer.findFieldByName("value"), 1) + .build() + + val df = Seq(newProducerMessage.toByteArray).toDF("newProducerData") + val fromProtoDf = df.select( + functions + .from_protobuf($"newProducerData", testFileDesc, "oldConsumer") + .as("oldConsumerProto")) + + val expectedFieldNames = oldConsumer.getFields.asScala.map(f => f.getName) + fromProtoDf.select("oldConsumerProto.*").schema.fields.toSeq.map { f => + { + assert(expectedFieldNames.contains(f.name)) + } + } + } + + test("roundtrip in to_protobuf and from_protobuf - with nulls") { + val schema = StructType( + StructField("requiredMsg", + StructType( + StructField("key", StringType, nullable = false) :: + StructField("col_1", IntegerType, nullable = true) :: + StructField("col_2", StringType, nullable = false) :: + StructField("col_3", IntegerType, nullable = true) :: Nil + ), + nullable = true + ) :: Nil + ) + val inputDf = spark.createDataFrame( + spark.sparkContext.parallelize(Seq( + Row(Row("key1", null, "value2", null)) + )), + schema + ) + val toProtobuf = inputDf.select( + functions.to_protobuf($"requiredMsg", testFileDesc, "requiredMsg") + .as("to_proto")) + + val binary = toProtobuf.take(1).toSeq(0).get(0).asInstanceOf[Array[Byte]] + + val messageDescriptor = ProtobufUtils.buildDescriptor(testFileDesc, "requiredMsg") + val actualMessage = DynamicMessage.parseFrom(messageDescriptor, binary) + + assert(actualMessage.getField(messageDescriptor.findFieldByName("key")) + == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0)) + assert(actualMessage.getField(messageDescriptor.findFieldByName("col_2")) + == inputDf.select("requiredMsg.col_2").take(1).toSeq(0).get(0)) + 
assert(actualMessage.getField(messageDescriptor.findFieldByName("col_1")) == 0) + assert(actualMessage.getField(messageDescriptor.findFieldByName("col_3")) == 0) + + val fromProtoDf = toProtobuf.select( + functions.from_protobuf($"to_proto", testFileDesc, "requiredMsg") as 'from_proto) + + assert(fromProtoDf.select("from_proto.key").take(1).toSeq(0).get(0) + == inputDf.select("requiredMsg.key").take(1).toSeq(0).get(0)) + assert(fromProtoDf.select("from_proto.col_2").take(1).toSeq(0).get(0) + == inputDf.select("requiredMsg.col_2").take(1).toSeq(0).get(0)) + assert(fromProtoDf.select("from_proto.col_1").take(1).toSeq(0).get(0) == null) + assert(fromProtoDf.select("from_proto.col_3").take(1).toSeq(0).get(0) == null) + } + + test("from_protobuf filter to_protobuf") { + val basicMessageDesc = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage") + + val basicMessage = DynamicMessage + .newBuilder(basicMessageDesc) + .setField(basicMessageDesc.findFieldByName("id"), 1111L) + .setField(basicMessageDesc.findFieldByName("string_value"), "slam") + .setField(basicMessageDesc.findFieldByName("int32_value"), 12345) + .setField(basicMessageDesc.findFieldByName("int64_value"), 0x90000000000L) + .setField(basicMessageDesc.findFieldByName("double_value"), 10000000000.0d) + .setField(basicMessageDesc.findFieldByName("float_value"), 10902.0f) + .setField(basicMessageDesc.findFieldByName("bool_value"), true) + .setField( + basicMessageDesc.findFieldByName("bytes_value"), + ByteString.copyFromUtf8("ProtobufDeserializer")) + .build() + + val df = Seq(basicMessage.toByteArray).toDF("value") + val resultFrom = df + .select(functions.from_protobuf($"value", testFileDesc, "BasicMessage") as 'sample) + .where("sample.string_value == \"slam\"") + + val resultToFrom = resultFrom + .select(functions.to_protobuf($"sample", testFileDesc, "BasicMessage") as 'value) + .select(functions.from_protobuf($"value", testFileDesc, "BasicMessage") as 'sample) + .where("sample.string_value == \"slam\"") + + assert(resultFrom.except(resultToFrom).isEmpty) + } + + test("Handle TimestampType between to_protobuf and from_protobuf") { + val schema = StructType( + StructField("timeStampMsg", + StructType( + StructField("key", StringType, nullable = true) :: + StructField("stmp", TimestampType, nullable = true) :: Nil + ), + nullable = true + ) :: Nil + ) + + val inputDf = spark.createDataFrame( + spark.sparkContext.parallelize(Seq( + Row(Row("key1", Timestamp.valueOf("2016-05-09 10:12:43.999"))) + )), + schema + ) + + val toProtoDf = inputDf + .select(functions.to_protobuf($"timeStampMsg", testFileDesc, "timeStampMsg") as 'to_proto) + + val fromProtoDf = toProtoDf + .select(functions.from_protobuf($"to_proto", testFileDesc, "timeStampMsg") as 'timeStampMsg) + fromProtoDf.show(truncate = false) + + val actualFields = fromProtoDf.schema.fields.toList + val expectedFields = inputDf.schema.fields.toList + + assert(actualFields.size === expectedFields.size) + assert(actualFields === expectedFields) + assert(fromProtoDf.select("timeStampMsg.key").take(1).toSeq(0).get(0) + === inputDf.select("timeStampMsg.key").take(1).toSeq(0).get(0)) + assert(fromProtoDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0) + === inputDf.select("timeStampMsg.stmp").take(1).toSeq(0).get(0)) + } + + test("Handle DayTimeIntervalType between to_protobuf and from_protobuf") { + val schema = StructType( + StructField("durationMsg", + StructType( + StructField("key", StringType, nullable = true) :: + StructField("duration", + 
DayTimeIntervalType.defaultConcreteType, nullable = true) :: Nil + ), + nullable = true + ) :: Nil + ) + + val inputDf = spark.createDataFrame( + spark.sparkContext.parallelize(Seq( + Row(Row("key1", + Duration.ofDays(1).plusHours(2).plusMinutes(3).plusSeconds(4) + )) + )), + schema + ) + + val toProtoDf = inputDf + .select(functions.to_protobuf($"durationMsg", testFileDesc, "durationMsg") as 'to_proto) + + val fromProtoDf = toProtoDf + .select(functions.from_protobuf($"to_proto", testFileDesc, "durationMsg") as 'durationMsg) + + val actualFields = fromProtoDf.schema.fields.toList + val expectedFields = inputDf.schema.fields.toList + + assert(actualFields.size === expectedFields.size) + assert(actualFields === expectedFields) + assert(fromProtoDf.select("durationMsg.key").take(1).toSeq(0).get(0) + === inputDf.select("durationMsg.key").take(1).toSeq(0).get(0)) + assert(fromProtoDf.select("durationMsg.duration").take(1).toSeq(0).get(0) + === inputDf.select("durationMsg.duration").take(1).toSeq(0).get(0)) + + } +} diff --git a/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala new file mode 100644 index 0000000000000..37c59743e7714 --- /dev/null +++ b/connector/protobuf/src/test/scala/org/apache/spark/sql/protobuf/ProtobufSerdeSuite.scala @@ -0,0 +1,224 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.protobuf + +import com.google.protobuf.Descriptors.Descriptor +import com.google.protobuf.DynamicMessage + +import org.apache.spark.sql.catalyst.NoopFilters +import org.apache.spark.sql.protobuf.utils.ProtobufUtils +import org.apache.spark.sql.protobuf.utils.SchemaConverters.IncompatibleSchemaException +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.{IntegerType, StructType} + +/** + * Tests for [[ProtobufSerializer]] and [[ProtobufDeserializer]] with a more specific focus on + * those classes. 
+ */ +class ProtobufSerdeSuite extends SharedSparkSession { + + import ProtoSerdeSuite._ + import ProtoSerdeSuite.MatchType._ + + val testFileDesc = testFile("protobuf/serde_suite.desc").replace("file:/", "/") + + test("Test basic conversion") { + withFieldMatchType { fieldMatch => + val (top, nest) = fieldMatch match { + case BY_NAME => ("foo", "bar") + } + val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "BasicMessage") + + val dynamicMessageFoo = DynamicMessage + .newBuilder(protoFile.getFile.findMessageTypeByName("Foo")) + .setField(protoFile.getFile.findMessageTypeByName("Foo").findFieldByName("bar"), 10902) + .build() + + val dynamicMessage = DynamicMessage + .newBuilder(protoFile) + .setField(protoFile.findFieldByName("foo"), dynamicMessageFoo) + .build() + + val serializer = Serializer.create(CATALYST_STRUCT, protoFile, fieldMatch) + val deserializer = Deserializer.create(CATALYST_STRUCT, protoFile, fieldMatch) + + assert( + serializer.serialize(deserializer.deserialize(dynamicMessage).get) === dynamicMessage) + } + } + + test("Fail to convert with field type mismatch") { + val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "MissMatchTypeInRoot") + + withFieldMatchType { fieldMatch => + assertFailedConversionMessage( + protoFile, + Deserializer, + fieldMatch, + "Cannot convert Protobuf field 'foo' to SQL field 'foo' because schema is incompatible " + + s"(protoType = org.apache.spark.sql.protobuf.MissMatchTypeInRoot.foo " + + s"LABEL_OPTIONAL LONG INT64, sqlType = ${CATALYST_STRUCT.head.dataType.sql})".stripMargin) + + assertFailedConversionMessage( + protoFile, + Serializer, + fieldMatch, + s"Cannot convert SQL field 'foo' to Protobuf field 'foo' because schema is incompatible " + + s"""(sqlType = ${CATALYST_STRUCT.head.dataType.sql}, protoType = LONG)""") + } + } + + test("Fail to convert with missing nested Protobuf fields for serializer") { + val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInProto") + + val nonnullCatalyst = new StructType() + .add("foo", new StructType().add("bar", IntegerType, nullable = false)) + + // serialize fails whether or not 'bar' is nullable + val byNameMsg = "Cannot find field 'foo.bar' in Protobuf schema" + assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg) + assertFailedConversionMessage(protoFile, Serializer, BY_NAME, byNameMsg, nonnullCatalyst) + } + + test("Fail to convert with deeply nested field type mismatch") { + val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "MissMatchTypeInDeepNested") + val catalyst = new StructType().add("top", CATALYST_STRUCT) + + withFieldMatchType { fieldMatch => + assertFailedConversionMessage( + protoFile, + Deserializer, + fieldMatch, + s"Cannot convert Protobuf field 'top.foo.bar' to SQL field 'top.foo.bar' because schema " + + s"is incompatible (protoType = org.apache.spark.sql.protobuf.TypeMiss.bar " + + s"LABEL_OPTIONAL LONG INT64, sqlType = INT)".stripMargin, + catalyst) + + assertFailedConversionMessage( + protoFile, + Serializer, + fieldMatch, + "Cannot convert SQL field 'top.foo.bar' to Protobuf field 'top.foo.bar' because schema " + + """is incompatible (sqlType = INT, protoType = LONG)""", + catalyst) + } + } + + test("Fail to convert with missing Catalyst fields") { + val protoFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLRoot") + + // serializing with extra fails if extra field is missing in SQL Schema + assertFailedConversionMessage( + protoFile, + Serializer, + BY_NAME, + "Found field 'boo' in 
Protobuf schema but there is no match in the SQL schema") + + /* deserializing should work regardless of whether the extra field is missing + in SQL Schema or not */ + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _)) + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoFile, _)) + + val protoNestedFile = ProtobufUtils.buildDescriptor(testFileDesc, "FieldMissingInSQLNested") + + // serializing with extra fails if extra field is missing in SQL Schema + assertFailedConversionMessage( + protoNestedFile, + Serializer, + BY_NAME, + "Found field 'foo.baz' in Protobuf schema but there is no match in the SQL schema") + + /* deserializing should work regardless of whether the extra field is missing + in SQL Schema or not */ + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _)) + withFieldMatchType(Deserializer.create(CATALYST_STRUCT, protoNestedFile, _)) + } + + /** + * Attempt to convert `catalystSchema` to `protoSchema` (or vice-versa if `deserialize` is + * true), assert that it fails, and assert that the _cause_ of the thrown exception has a + * message matching `expectedCauseMessage`. + */ + private def assertFailedConversionMessage( + protoSchema: Descriptor, + serdeFactory: SerdeFactory[_], + fieldMatchType: MatchType, + expectedCauseMessage: String, + catalystSchema: StructType = CATALYST_STRUCT): Unit = { + val e = intercept[IncompatibleSchemaException] { + serdeFactory.create(catalystSchema, protoSchema, fieldMatchType) + } + val expectMsg = serdeFactory match { + case Deserializer => + s"Cannot convert Protobuf type ${protoSchema.getName} to SQL type ${catalystSchema.sql}." + case Serializer => + s"Cannot convert SQL type ${catalystSchema.sql} to Protobuf type ${protoSchema.getName}." + } + + assert(e.getMessage === expectMsg) + assert(e.getCause.getMessage === expectedCauseMessage) + } + + def withFieldMatchType(f: MatchType => Unit): Unit = { + MatchType.values.foreach { fieldMatchType => + withClue(s"fieldMatchType == $fieldMatchType") { + f(fieldMatchType) + } + } + } +} + +object ProtoSerdeSuite { + + val CATALYST_STRUCT = + new StructType().add("foo", new StructType().add("bar", IntegerType)) + + /** + * Specifier for type of field matching to be used for easy creation of tests that do by-name + * field matching. + */ + object MatchType extends Enumeration { + type MatchType = Value + val BY_NAME = Value + } + + import MatchType._ + + /** + * Specifier for type of serde to be used for easy creation of tests that do both serialization + * and deserialization. 
+ */
+  sealed trait SerdeFactory[T] {
+    def create(sqlSchema: StructType, descriptor: Descriptor, fieldMatchType: MatchType): T
+  }
+
+  object Serializer extends SerdeFactory[ProtobufSerializer] {
+    override def create(
+        sql: StructType,
+        descriptor: Descriptor,
+        matchType: MatchType): ProtobufSerializer = new ProtobufSerializer(sql, descriptor, false)
+  }
+
+  object Deserializer extends SerdeFactory[ProtobufDeserializer] {
+    override def create(
+        sql: StructType,
+        descriptor: Descriptor,
+        matchType: MatchType): ProtobufDeserializer =
+      new ProtobufDeserializer(descriptor, sql, new NoopFilters)
+  }
+}
diff --git a/pom.xml b/pom.xml
index f82546e4f3e9c..7258f970bab7c 100644
--- a/pom.xml
+++ b/pom.xml
@@ -101,6 +101,7 @@
     <module>connector/kafka-0-10-sql</module>
     <module>connector/avro</module>
     <module>connector/connect</module>
+    <module>connector/protobuf</module>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 1de8bc6a47ded..15fa3a3143b60 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -45,8 +45,8 @@ object BuildCommons {
   private val buildLocation = file(".").getAbsoluteFile.getParentFile
-  val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, tokenProviderKafka010, sqlKafka010, avro) = Seq(
-    "catalyst", "sql", "hive", "hive-thriftserver", "token-provider-kafka-0-10", "sql-kafka-0-10", "avro"
+  val sqlProjects@Seq(catalyst, sql, hive, hiveThriftServer, tokenProviderKafka010, sqlKafka010, avro, protobuf) = Seq(
+    "catalyst", "sql", "hive", "hive-thriftserver", "token-provider-kafka-0-10", "sql-kafka-0-10", "avro", "protobuf"
   ).map(ProjectRef(buildLocation, _))
   val streamingProjects@Seq(streaming, streamingKafka010) =
@@ -59,7 +59,7 @@ object BuildCommons {
   ) = Seq(
     "core", "graphx", "mllib", "mllib-local", "repl", "network-common", "network-shuffle", "launcher", "unsafe",
     "tags", "sketch", "kvstore"
-  ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ Seq(connect)
+  ).map(ProjectRef(buildLocation, _)) ++ sqlProjects ++ streamingProjects ++ Seq(connect) ++ Seq(protobuf)
   val optionallyEnabledProjects@Seq(kubernetes, mesos, yarn,
     sparkGangliaLgpl, streamingKinesisAsl,
@@ -390,7 +390,7 @@ object SparkBuild extends PomBuild {
   val mimaProjects = allProjects.filterNot { x =>
     Seq(
       spark, hive, hiveThriftServer, repl, networkCommon, networkShuffle, networkYarn,
-      unsafe, tags, tokenProviderKafka010, sqlKafka010, connect
+      unsafe, tags, tokenProviderKafka010, sqlKafka010, connect, protobuf
     ).contains(x)
   }
@@ -433,6 +433,9 @@ object SparkBuild extends PomBuild {
   enable(SparkConnect.settings)(connect)
+  /* Connector/proto settings */
+  enable(SparkProtobuf.settings)(protobuf)
+
   // SPARK-14738 - Remove docker tests from main Spark build
   // enable(DockerIntegrationTests.settings)(dockerIntegrationTests)
@@ -662,6 +665,48 @@ object SparkConnect {
   )
 }
+object SparkProtobuf {
+
+  import BuildCommons.protoVersion
+
+  private val shadePrefix = "org.sparkproject.spark-protobuf"
+  val shadeJar = taskKey[Unit]("Shade the Jars")
+
+  lazy val settings = Seq(
+    // Setting version for the protobuf compiler. This has to be propagated to every sub-project
+    // even if the project is not using it.
+    PB.protocVersion := BuildCommons.protoVersion,
+
+    // For some reason the resolution from the imported Maven build does not work for some
+    // of these dependencies that we need to shade later on.
+    libraryDependencies ++= Seq(
+      "com.google.protobuf" % "protobuf-java" % protoVersion % "protobuf"
+    ),
+
+    dependencyOverrides ++= Seq(
+      "com.google.protobuf" % "protobuf-java" % protoVersion
+    ),
+
+    (Compile / PB.targets) := Seq(
+      PB.gens.java -> (Compile / sourceManaged).value,
+    ),
+
+    (assembly / test) := false,
+
+    (assembly / logLevel) := Level.Info,
+
+    (assembly / assemblyShadeRules) := Seq(
+      ShadeRule.rename("com.google.protobuf.**" -> "org.sparkproject.spark-protobuf.protobuf.@1").inAll,
+    ),
+
+    (assembly / assemblyMergeStrategy) := {
+      case m if m.toLowerCase(Locale.ROOT).endsWith("manifest.mf") => MergeStrategy.discard
+      // Drop all proto files that are not needed as artifacts of the build.
+      case m if m.toLowerCase(Locale.ROOT).endsWith(".proto") => MergeStrategy.discard
+      case _ => MergeStrategy.first
+    },
+  )
+}
 object Unsafe {
   lazy val settings = Seq(
     // This option is needed to suppress warnings from sun.misc.Unsafe usage
@@ -1107,10 +1152,10 @@ object Unidoc {
     (ScalaUnidoc / unidoc / unidocProjectFilter) :=
       inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes,
-        yarn, tags, streamingKafka010, sqlKafka010, connect),
+        yarn, tags, streamingKafka010, sqlKafka010, connect, protobuf),
     (JavaUnidoc / unidoc / unidocProjectFilter) :=
       inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, kubernetes,
-        yarn, tags, streamingKafka010, sqlKafka010, connect),
+        yarn, tags, streamingKafka010, sqlKafka010, connect, protobuf),
     (ScalaUnidoc / unidoc / unidocAllClasspaths) := {
       ignoreClasspaths((ScalaUnidoc / unidoc / unidocAllClasspaths).value)
@@ -1196,6 +1241,7 @@ object CopyDependencies {
       // produce the shaded Jar which happens automatically in the case of Maven.
       // Later, when the dependencies are copied, we manually copy the shaded Jar only.
       val fid = (LocalProject("connect") / assembly).value
+      val fidProtobuf = (LocalProject("protobuf") / assembly).value
       (Compile / dependencyClasspath).value.map(_.data)
         .filter { jar => jar.isFile() }
@@ -1208,6 +1254,9 @@
         if (jar.getName.contains("spark-connect") &&
           !SbtPomKeys.profiles.value.contains("noshade-connect")) {
           Files.copy(fid.toPath, destJar.toPath)
+        } else if (jar.getName.contains("spark-protobuf") &&
+          !SbtPomKeys.profiles.value.contains("noshade-protobuf")) {
+          // Copy the shaded spark-protobuf assembly, not the connect assembly bound to fid.
+          Files.copy(fidProtobuf.toPath, destJar.toPath)
         } else {
           Files.copy(jar.toPath(), destJar.toPath())
         }
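
For reviewers who want to exercise the new module end to end, the round-trip pattern used throughout ProtobufFunctionsSuite reduces to the sketch below. It mirrors the argument order the tests use (column, descriptor file path, message name). The descriptor path, the standalone object name, and the input column names are illustrative assumptions only; the Person message and its name/age fields come from catalyst_types.proto above.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.struct
import org.apache.spark.sql.protobuf.functions.{from_protobuf, to_protobuf}

object ProtobufRoundTripSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("protobuf-sketch").getOrCreate()
    import spark.implicits._

    // Hypothetical descriptor set, e.g. produced with:
    //   protoc --descriptor_set_out=/tmp/person.desc person.proto
    val descFile = "/tmp/person.desc"

    // Build a struct column shaped like the Person message (name: string, age: int32).
    val df = Seq(("Maxim", 39)).toDF("name", "age")
      .select(struct($"name", $"age").as("person"))

    // Catalyst struct -> serialized Protobuf bytes.
    val asProto = df.select(to_protobuf($"person", descFile, "Person").as("bytes"))

    // Bytes -> struct again; the round trip should preserve the original values,
    // which is the property the suites assert with checkAnswer.
    val back = asProto.select(from_protobuf($"bytes", descFile, "Person").as("person"))
    back.select("person.*").show()

    spark.stop()
  }
}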