diff --git a/docs/sql-migration-guide-upgrade.md b/docs/sql-migration-guide-upgrade.md index 25cd541190919..55838e773e4b1 100644 --- a/docs/sql-migration-guide-upgrade.md +++ b/docs/sql-migration-guide-upgrade.md @@ -27,6 +27,8 @@ displayTitle: Spark SQL Upgrading Guide - In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but users can still distinguish them via `Dataset.show`, `Dataset.collect` etc. Since Spark 3.0, float/double -0.0 is replaced by 0.0 internally, and users can't distinguish them any more. + - In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of a map with duplicated keys is undefined, e.g. map lookup respects the duplicated key that appears first, `Dataset.collect` only keeps the duplicated key that appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with the last-wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet); the behavior is undefined. + ## Upgrading From Spark SQL 2.3 to 2.4 - In Spark version 2.3 and earlier, the second parameter to array_contains function is implicitly promoted to the element type of first array type parameter. This type promotion can be lossy and may cause `array_contains` function to return wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and are illustrated in the table below. diff --git a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index 272e7d5b388d9..4e2224b058a0a 100644 --- a/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/external/avro/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.avro -import java.math.{BigDecimal} +import java.math.BigDecimal import java.nio.ByteBuffer import scala.collection.JavaConverters._ @@ -218,6 +218,8 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) { i += 1 } + // The Avro map will never have null or duplicated map keys, it's safe to create a + // ArrayBasedMapData directly here.
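To make the last-wins semantics described in the migration note above concrete, here is a small illustrative sketch (not part of the patch); it assumes a local Spark 3.0 `SparkSession`, and the expected output is inferred from the note rather than copied from a real session:

```scala
import org.apache.spark.sql.SparkSession

// Minimal sketch, assuming Spark 3.0 behavior as described in the migration note above.
val spark = SparkSession.builder().master("local[1]").appName("dup-map-keys").getOrCreate()

// Since 3.0 the built-in map() removes the duplicated key and keeps the last value.
spark.sql("SELECT map(1, 'a', 1, 'b') AS m").show(truncate = false)
// Expected on 3.0:  [1 -> b]
// On 2.4 and earlier both entries were kept, and lookup vs. collect could disagree on which
// entry won -- which is why the old behavior is described as undefined.
spark.stop()
```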
updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) case (UNION, _) => diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index 286ef219a69e9..f98e550e39da8 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2656,11 +2656,11 @@ def map_concat(*cols): >>> from pyspark.sql.functions import map_concat >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2") >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False) - +--------------------------------+ - |map3 | - +--------------------------------+ - |[1 -> a, 2 -> b, 3 -> c, 1 -> d]| - +--------------------------------+ + +------------------------+ + |map3 | + +------------------------+ + |[1 -> d, 2 -> b, 3 -> c]| + +------------------------+ """ sc = SparkContext._active_spark_context if len(cols) == 1 and isinstance(cols[0], (list, set)): diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java index f17441dfccb6d..a0833a6df8bbd 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/UnsafeMapData.java @@ -28,6 +28,9 @@ * Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 8 bytes at head * to indicate the number of bytes of the unsafe key array. * [unsafe key array numBytes] [unsafe key array] [unsafe value array] + * + * Note that, user is responsible to guarantee that the key array does not have duplicated + * elements, otherwise the behavior is undefined. */ // TODO: Use a more efficient format which doesn't depend on unsafe array. public final class UnsafeMapData extends MapData { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala index 6f5fbdd79e668..93df73ab1eaf6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala @@ -431,12 +431,6 @@ object CatalystTypeConverters { map, (key: Any) => convertToCatalyst(key), (value: Any) => convertToCatalyst(value)) - case (keys: Array[_], values: Array[_]) => - // case for mapdata with duplicate keys - new ArrayBasedMapData( - new GenericArrayData(keys.map(convertToCatalyst)), - new GenericArrayData(values.map(convertToCatalyst)) - ) case other => other } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala index 274d75e680f03..e49c10be6be4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/InternalRow.scala @@ -125,22 +125,36 @@ object InternalRow { * actually takes a `SpecializedGetters` input because it can be generalized to other classes * that implements `SpecializedGetters` (e.g., `ArrayData`) too. 
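The `getAccessor` rewrite just below folds the per-row null check into the returned closure, so callers such as `BoundReference` and `LambdaVariable` (changed later in this diff) no longer branch on `nullable` at every `eval` call. The following standalone sketch shows the same pattern with a made-up `SimpleGetters` trait standing in for `SpecializedGetters`, so it is only an illustration of the technique, not Spark's actual code:

```scala
// Illustrative only: a null-aware accessor factory in the style of InternalRow.getAccessor.
// `SimpleGetters` is a hypothetical stand-in so the example is self-contained.
trait SimpleGetters {
  def isNullAt(ordinal: Int): Boolean
  def getInt(ordinal: Int): Int
}

def intAccessor(nullable: Boolean): (SimpleGetters, Int) => Any = {
  val getValueNullSafe: (SimpleGetters, Int) => Any = (input, ordinal) => input.getInt(ordinal)
  if (nullable) {
    // The null check is built into the closure once, instead of being repeated by every caller.
    (input, ordinal) => if (input.isNullAt(ordinal)) null else getValueNullSafe(input, ordinal)
  } else {
    getValueNullSafe
  }
}
```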
*/ - def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match { - case BooleanType => (input, ordinal) => input.getBoolean(ordinal) - case ByteType => (input, ordinal) => input.getByte(ordinal) - case ShortType => (input, ordinal) => input.getShort(ordinal) - case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal) - case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal) - case FloatType => (input, ordinal) => input.getFloat(ordinal) - case DoubleType => (input, ordinal) => input.getDouble(ordinal) - case StringType => (input, ordinal) => input.getUTF8String(ordinal) - case BinaryType => (input, ordinal) => input.getBinary(ordinal) - case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal) - case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale) - case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size) - case _: ArrayType => (input, ordinal) => input.getArray(ordinal) - case _: MapType => (input, ordinal) => input.getMap(ordinal) - case u: UserDefinedType[_] => getAccessor(u.sqlType) - case _ => (input, ordinal) => input.get(ordinal, dataType) + def getAccessor(dt: DataType, nullable: Boolean = true): (SpecializedGetters, Int) => Any = { + val getValueNullSafe: (SpecializedGetters, Int) => Any = dt match { + case BooleanType => (input, ordinal) => input.getBoolean(ordinal) + case ByteType => (input, ordinal) => input.getByte(ordinal) + case ShortType => (input, ordinal) => input.getShort(ordinal) + case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal) + case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal) + case FloatType => (input, ordinal) => input.getFloat(ordinal) + case DoubleType => (input, ordinal) => input.getDouble(ordinal) + case StringType => (input, ordinal) => input.getUTF8String(ordinal) + case BinaryType => (input, ordinal) => input.getBinary(ordinal) + case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal) + case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale) + case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size) + case _: ArrayType => (input, ordinal) => input.getArray(ordinal) + case _: MapType => (input, ordinal) => input.getMap(ordinal) + case u: UserDefinedType[_] => getAccessor(u.sqlType, nullable) + case _ => (input, ordinal) => input.get(ordinal, dt) + } + + if (nullable) { + (getter, index) => { + if (getter.isNullAt(index)) { + null + } else { + getValueNullSafe(getter, index) + } + } + } else { + getValueNullSafe + } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala index 77582e10f9ff2..ea8c369ee49ed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala @@ -34,15 +34,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean) override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]" - private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType, nullable) // Use special getter for primitive types (for UnsafeRow) override def eval(input: InternalRow): 
Any = { - if (nullable && input.isNullAt(ordinal)) { - null - } else { - accessor(input, ordinal) - } + accessor(input, ordinal) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala index 43116743e9952..fa8e38acd522d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala @@ -546,33 +546,25 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres override def nullable: Boolean = children.exists(_.nullable) + private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) + override def eval(input: InternalRow): Any = { - val maps = children.map(_.eval(input)) + val maps = children.map(_.eval(input).asInstanceOf[MapData]) if (maps.contains(null)) { return null } - val keyArrayDatas = maps.map(_.asInstanceOf[MapData].keyArray()) - val valueArrayDatas = maps.map(_.asInstanceOf[MapData].valueArray()) - val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements()) + val numElements = maps.foldLeft(0L)((sum, ad) => sum + ad.numElements()) if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) { throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " + s"elements due to exceeding the map size limit " + s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.") } - val finalKeyArray = new Array[AnyRef](numElements.toInt) - val finalValueArray = new Array[AnyRef](numElements.toInt) - var position = 0 - for (i <- keyArrayDatas.indices) { - val keyArray = keyArrayDatas(i).toObjectArray(dataType.keyType) - val valueArray = valueArrayDatas(i).toObjectArray(dataType.valueType) - Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length) - Array.copy(valueArray, 0, finalValueArray, position, valueArray.length) - position += keyArray.length - } - new ArrayBasedMapData(new GenericArrayData(finalKeyArray), - new GenericArrayData(finalValueArray)) + for (map <- maps) { + mapBuilder.putAll(map.keyArray(), map.valueArray()) + } + mapBuilder.build() } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { @@ -581,16 +573,7 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres val valueType = dataType.valueType val argsName = ctx.freshName("args") val hasNullName = ctx.freshName("hasNull") - val mapDataClass = classOf[MapData].getName - val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName - val arrayDataClass = classOf[ArrayData].getName - - val init = - s""" - |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}]; - |boolean ${ev.isNull}, $hasNullName = false; - |$mapDataClass ${ev.value} = null; - """.stripMargin + val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder) val assignments = mapCodes.zip(children.map(_.nullable)).zipWithIndex.map { case ((m, true), i) => @@ -613,10 +596,10 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres """.stripMargin } - val codes = ctx.splitExpressionsWithCurrentInputs( + val prepareMaps = ctx.splitExpressionsWithCurrentInputs( expressions = assignments, funcName = "getMapConcatInputs", - extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", hasNullName) :: Nil, + 
extraArguments = (s"MapData[]", argsName) :: ("boolean", hasNullName) :: Nil, returnType = "boolean", makeSplitFunction = body => s""" @@ -646,34 +629,34 @@ case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpres val mapMerge = s""" - |${ev.isNull} = $hasNullName; - |if (!${ev.isNull}) { - | $arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}]; - | $arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}]; - | long $numElementsName = 0; - | for (int $idxName = 0; $idxName < $argsName.length; $idxName++) { - | $keyArgsName[$idxName] = $argsName[$idxName].keyArray(); - | $valArgsName[$idxName] = $argsName[$idxName].valueArray(); - | $numElementsName += $argsName[$idxName].numElements(); - | } - | if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | throw new RuntimeException("Unsuccessful attempt to concat maps with " + - | $numElementsName + " elements due to exceeding the map size limit " + - | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); - | } - | $arrayDataClass $finKeysName = $keyConcat($keyArgsName, - | (int) $numElementsName); - | $arrayDataClass $finValsName = $valueConcat($valArgsName, - | (int) $numElementsName); - | ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName); + |ArrayData[] $keyArgsName = new ArrayData[${mapCodes.size}]; + |ArrayData[] $valArgsName = new ArrayData[${mapCodes.size}]; + |long $numElementsName = 0; + |for (int $idxName = 0; $idxName < $argsName.length; $idxName++) { + | $keyArgsName[$idxName] = $argsName[$idxName].keyArray(); + | $valArgsName[$idxName] = $argsName[$idxName].valueArray(); + | $numElementsName += $argsName[$idxName].numElements(); |} + |if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { + | throw new RuntimeException("Unsuccessful attempt to concat maps with " + + | $numElementsName + " elements due to exceeding the map size limit " + + | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}."); + |} + |ArrayData $finKeysName = $keyConcat($keyArgsName, (int) $numElementsName); + |ArrayData $finValsName = $valueConcat($valArgsName, (int) $numElementsName); + |${ev.value} = $builderTerm.from($finKeysName, $finValsName); """.stripMargin ev.copy( code = code""" - |$init - |$codes - |$mapMerge + |MapData[] $argsName = new MapData[${mapCodes.size}]; + |boolean $hasNullName = false; + |$prepareMaps + |boolean ${ev.isNull} = $hasNullName; + |MapData ${ev.value} = null; + |if (!$hasNullName) { + | $mapMerge + |} """.stripMargin) } @@ -751,171 +734,44 @@ case class MapFromEntries(child: Expression) extends UnaryExpression { s"${child.dataType.catalogString} type. 
$prettyName accepts only arrays of pair structs.") } + private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) + override protected def nullSafeEval(input: Any): Any = { - val arrayData = input.asInstanceOf[ArrayData] - val numEntries = arrayData.numElements() + val entries = input.asInstanceOf[ArrayData] + val numEntries = entries.numElements() var i = 0 - if(nullEntries) { + if (nullEntries) { while (i < numEntries) { - if (arrayData.isNullAt(i)) return null + if (entries.isNullAt(i)) return null i += 1 } } - val keyArray = new Array[AnyRef](numEntries) - val valueArray = new Array[AnyRef](numEntries) + i = 0 while (i < numEntries) { - val entry = arrayData.getStruct(i, 2) - val key = entry.get(0, dataType.keyType) - if (key == null) { - throw new RuntimeException("The first field from a struct (key) can't be null.") - } - keyArray.update(i, key) - val value = entry.get(1, dataType.valueType) - valueArray.update(i, value) + mapBuilder.put(entries.getStruct(i, 2)) i += 1 } - ArrayBasedMapData(keyArray, valueArray) + mapBuilder.build() } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, c => { val numEntries = ctx.freshName("numEntries") - val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType) - val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) - val code = if (isKeyPrimitive && isValuePrimitive) { - genCodeForPrimitiveElements(ctx, c, ev.value, numEntries) - } else { - genCodeForAnyElements(ctx, c, ev.value, numEntries) - } + val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder) + val i = ctx.freshName("idx") ctx.nullArrayElementsSaveExec(nullEntries, ev.isNull, c) { s""" |final int $numEntries = $c.numElements(); - |$code + |for (int $i = 0; $i < $numEntries; $i++) { + | $builderTerm.put($c.getStruct($i, 2)); + |} + |${ev.value} = $builderTerm.build(); """.stripMargin } }) } - private def genCodeForAssignmentLoop( - ctx: CodegenContext, - childVariable: String, - mapData: String, - numEntries: String, - keyAssignment: (String, String) => String, - valueAssignment: (String, String) => String): String = { - val entry = ctx.freshName("entry") - val i = ctx.freshName("idx") - - val nullKeyCheck = if (dataTypeDetails.get._2) { - s""" - |if ($entry.isNullAt(0)) { - | throw new RuntimeException("The first field from a struct (key) can't be null."); - |} - """.stripMargin - } else { - "" - } - - s""" - |for (int $i = 0; $i < $numEntries; $i++) { - | InternalRow $entry = $childVariable.getStruct($i, 2); - | $nullKeyCheck - | ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), i)} - | ${valueAssignment(entry, i)} - |} - """.stripMargin - } - - private def genCodeForPrimitiveElements( - ctx: CodegenContext, - childVariable: String, - mapData: String, - numEntries: String): String = { - val byteArraySize = ctx.freshName("byteArraySize") - val keySectionSize = ctx.freshName("keySectionSize") - val valueSectionSize = ctx.freshName("valueSectionSize") - val data = ctx.freshName("byteArray") - val unsafeMapData = ctx.freshName("unsafeMapData") - val keyArrayData = ctx.freshName("keyArrayData") - val valueArrayData = ctx.freshName("valueArrayData") - - val baseOffset = Platform.BYTE_ARRAY_OFFSET - val keySize = dataType.keyType.defaultSize - val valueSize = dataType.valueType.defaultSize - val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)" - val vByteSize = 
s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)" - - val keyAssignment = (key: String, idx: String) => - CodeGenerator.setArrayElement(keyArrayData, dataType.keyType, idx, key) - val valueAssignment = (entry: String, idx: String) => - CodeGenerator.createArrayAssignment( - valueArrayData, dataType.valueType, entry, idx, "1", dataType.valueContainsNull) - val assignmentLoop = genCodeForAssignmentLoop( - ctx, - childVariable, - mapData, - numEntries, - keyAssignment, - valueAssignment - ) - - s""" - |final long $keySectionSize = $kByteSize; - |final long $valueSectionSize = $vByteSize; - |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize; - |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) { - | ${genCodeForAnyElements(ctx, childVariable, mapData, numEntries)} - |} else { - | final byte[] $data = new byte[(int)$byteArraySize]; - | UnsafeMapData $unsafeMapData = new UnsafeMapData(); - | Platform.putLong($data, $baseOffset, $keySectionSize); - | Platform.putLong($data, ${baseOffset + 8}, $numEntries); - | Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries); - | $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize); - | ArrayData $keyArrayData = $unsafeMapData.keyArray(); - | ArrayData $valueArrayData = $unsafeMapData.valueArray(); - | $assignmentLoop - | $mapData = $unsafeMapData; - |} - """.stripMargin - } - - private def genCodeForAnyElements( - ctx: CodegenContext, - childVariable: String, - mapData: String, - numEntries: String): String = { - val keys = ctx.freshName("keys") - val values = ctx.freshName("values") - val mapDataClass = classOf[ArrayBasedMapData].getName() - - val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType) - val valueAssignment = (entry: String, idx: String) => { - val value = CodeGenerator.getValue(entry, dataType.valueType, "1") - if (dataType.valueContainsNull && isValuePrimitive) { - s"$values[$idx] = $entry.isNullAt(1) ? 
null : (Object)$value;" - } else { - s"$values[$idx] = $value;" - } - } - val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;" - val assignmentLoop = genCodeForAssignmentLoop( - ctx, - childVariable, - mapData, - numEntries, - keyAssignment, - valueAssignment) - - s""" - |final Object[] $keys = new Object[$numEntries]; - |final Object[] $values = new Object[$numEntries]; - |$assignmentLoop - |$mapData = $mapDataClass.apply($keys, $values); - """.stripMargin - } - override def prettyName: String = "map_from_entries" } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala index 6b77996789f1a..4e722c9237a90 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypeCreator.scala @@ -24,8 +24,6 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.Platform -import org.apache.spark.unsafe.array.ByteArrayMethods import org.apache.spark.unsafe.types.UTF8String /** @@ -62,7 +60,7 @@ case class CreateArray(children: Seq[Expression]) extends Expression { override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val et = dataType.elementType val (allocation, assigns, arrayData) = - GenArrayData.genCodeToCreateArrayData(ctx, et, children, false, "createArray") + GenArrayData.genCodeToCreateArrayData(ctx, et, children, "createArray") ev.copy( code = code"${allocation}${assigns}", value = JavaCode.variable(arrayData, dataType), @@ -79,7 +77,6 @@ private [sql] object GenArrayData { * @param ctx a [[CodegenContext]] * @param elementType data type of underlying array elements * @param elementsExpr concatenated set of [[Expression]] for each element of an underlying array - * @param isMapKey if true, throw an exception when the element is null * @param functionName string to include in the error message * @return (array allocation, concatenated assignments to each array elements, arrayData name) */ @@ -87,7 +84,6 @@ private [sql] object GenArrayData { ctx: CodegenContext, elementType: DataType, elementsExpr: Seq[Expression], - isMapKey: Boolean, functionName: String): (String, String, String) = { val arrayDataName = ctx.freshName("arrayData") val numElements = s"${elementsExpr.length}L" @@ -103,15 +99,9 @@ private [sql] object GenArrayData { val assignment = if (!expr.nullable) { setArrayElement } else { - val isNullAssignment = if (!isMapKey) { - s"$arrayDataName.setNullAt($i);" - } else { - "throw new RuntimeException(\"Cannot use null as map key!\");" - } - s""" |if (${eval.isNull}) { - | $isNullAssignment + | $arrayDataName.setNullAt($i); |} else { | $setArrayElement |} @@ -165,7 +155,7 @@ case class CreateMap(children: Seq[Expression]) extends Expression { } } - override def dataType: MapType = { + override lazy val dataType: MapType = { MapType( keyType = TypeCoercion.findCommonTypeDifferentOnlyInNullFlags(keys.map(_.dataType)) .getOrElse(StringType), @@ -176,32 +166,33 @@ case class CreateMap(children: Seq[Expression]) extends Expression { override def nullable: Boolean = false + private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) + override def eval(input: 
InternalRow): Any = { - val keyArray = keys.map(_.eval(input)).toArray - if (keyArray.contains(null)) { - throw new RuntimeException("Cannot use null as map key!") + var i = 0 + while (i < keys.length) { + mapBuilder.put(keys(i).eval(input), values(i).eval(input)) + i += 1 } - val valueArray = values.map(_.eval(input)).toArray - new ArrayBasedMapData(new GenericArrayData(keyArray), new GenericArrayData(valueArray)) + mapBuilder.build() } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { - val mapClass = classOf[ArrayBasedMapData].getName val MapType(keyDt, valueDt, _) = dataType val (allocationKeyData, assignKeys, keyArrayData) = - GenArrayData.genCodeToCreateArrayData(ctx, keyDt, keys, true, "createMap") + GenArrayData.genCodeToCreateArrayData(ctx, keyDt, keys, "createMap") val (allocationValueData, assignValues, valueArrayData) = - GenArrayData.genCodeToCreateArrayData(ctx, valueDt, values, false, "createMap") + GenArrayData.genCodeToCreateArrayData(ctx, valueDt, values, "createMap") + val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder) val code = code""" - final boolean ${ev.isNull} = false; $allocationKeyData $assignKeys $allocationValueData $assignValues - final MapData ${ev.value} = new $mapClass($keyArrayData, $valueArrayData); + final MapData ${ev.value} = $builderTerm.from($keyArrayData, $valueArrayData); """ - ev.copy(code = code) + ev.copy(code = code, isNull = FalseLiteral) } override def prettyName: String = "map" @@ -234,53 +225,25 @@ case class MapFromArrays(left: Expression, right: Expression) } } - override def dataType: DataType = { + override def dataType: MapType = { MapType( keyType = left.dataType.asInstanceOf[ArrayType].elementType, valueType = right.dataType.asInstanceOf[ArrayType].elementType, valueContainsNull = right.dataType.asInstanceOf[ArrayType].containsNull) } + private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) + override def nullSafeEval(keyArray: Any, valueArray: Any): Any = { val keyArrayData = keyArray.asInstanceOf[ArrayData] val valueArrayData = valueArray.asInstanceOf[ArrayData] - if (keyArrayData.numElements != valueArrayData.numElements) { - throw new RuntimeException("The given two arrays should have the same length") - } - val leftArrayType = left.dataType.asInstanceOf[ArrayType] - if (leftArrayType.containsNull) { - var i = 0 - while (i < keyArrayData.numElements) { - if (keyArrayData.isNullAt(i)) { - throw new RuntimeException("Cannot use null as map key!") - } - i += 1 - } - } - new ArrayBasedMapData(keyArrayData.copy(), valueArrayData.copy()) + mapBuilder.from(keyArrayData.copy(), valueArrayData.copy()) } override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { nullSafeCodeGen(ctx, ev, (keyArrayData, valueArrayData) => { - val arrayBasedMapData = classOf[ArrayBasedMapData].getName - val leftArrayType = left.dataType.asInstanceOf[ArrayType] - val keyArrayElemNullCheck = if (!leftArrayType.containsNull) "" else { - val i = ctx.freshName("i") - s""" - |for (int $i = 0; $i < $keyArrayData.numElements(); $i++) { - | if ($keyArrayData.isNullAt($i)) { - | throw new RuntimeException("Cannot use null as map key!"); - | } - |} - """.stripMargin - } - s""" - |if ($keyArrayData.numElements() != $valueArrayData.numElements()) { - | throw new RuntimeException("The given two arrays should have the same length"); - |} - |$keyArrayElemNullCheck - |${ev.value} = new $arrayBasedMapData($keyArrayData.copy(), $valueArrayData.copy()); - """.stripMargin + val builderTerm = 
ctx.addReferenceObj("mapBuilder", mapBuilder) + s"${ev.value} = $builderTerm.from($keyArrayData.copy(), $valueArrayData.copy());" }) } @@ -488,28 +451,25 @@ case class StringToMap(text: Expression, pairDelim: Expression, keyValueDelim: E } } + private lazy val mapBuilder = new ArrayBasedMapBuilder(StringType, StringType) + override def nullSafeEval( inputString: Any, stringDelimiter: Any, keyValueDelimiter: Any): Any = { val keyValues = inputString.asInstanceOf[UTF8String].split(stringDelimiter.asInstanceOf[UTF8String], -1) - - val iterator = new Iterator[(UTF8String, UTF8String)] { - var index = 0 - val keyValueDelimiterUTF8String = keyValueDelimiter.asInstanceOf[UTF8String] - - override def hasNext: Boolean = { - keyValues.length > index - } - - override def next(): (UTF8String, UTF8String) = { - val keyValueArray = keyValues(index).split(keyValueDelimiterUTF8String, 2) - index += 1 - (keyValueArray(0), if (keyValueArray.length < 2) null else keyValueArray(1)) - } + val keyValueDelimiterUTF8String = keyValueDelimiter.asInstanceOf[UTF8String] + + var i = 0 + while (i < keyValues.length) { + val keyValueArray = keyValues(i).split(keyValueDelimiterUTF8String, 2) + val key = keyValueArray(0) + val value = if (keyValueArray.length < 2) null else keyValueArray(1) + mapBuilder.put(key, value) + i += 1 } - ArrayBasedMapData(iterator, keyValues.size, identity, identity) + mapBuilder.build() } override def prettyName: String = "str_to_map" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala index 8b31021866220..a8639d29f964d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/higherOrderFunctions.scala @@ -512,7 +512,7 @@ case class TransformKeys( @transient lazy val MapType(keyType, valueType, valueContainsNull) = argument.dataType - override def dataType: DataType = MapType(function.dataType, valueType, valueContainsNull) + override def dataType: MapType = MapType(function.dataType, valueType, valueContainsNull) override def checkInputDataTypes(): TypeCheckResult = { TypeUtils.checkForMapKeyType(function.dataType) @@ -525,6 +525,7 @@ case class TransformKeys( @transient lazy val LambdaFunction( _, (keyVar: NamedLambdaVariable) :: (valueVar: NamedLambdaVariable) :: Nil, _) = function + private lazy val mapBuilder = new ArrayBasedMapBuilder(dataType.keyType, dataType.valueType) override def nullSafeEval(inputRow: InternalRow, argumentValue: Any): Any = { val map = argumentValue.asInstanceOf[MapData] @@ -534,13 +535,10 @@ case class TransformKeys( keyVar.value.set(map.keyArray().get(i, keyVar.dataType)) valueVar.value.set(map.valueArray().get(i, valueVar.dataType)) val result = functionForEval.eval(inputRow) - if (result == null) { - throw new RuntimeException("Cannot use null as map key!") - } resultKeys.update(i, result) i += 1 } - new ArrayBasedMapData(resultKeys, map.valueArray()) + mapBuilder.from(resultKeys, map.valueArray()) } override def prettyName: String = "transform_keys" diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 59c897b6a53ce..8182730feb4b4 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -587,17 +587,13 @@ case class LambdaVariable( dataType: DataType, nullable: Boolean = true) extends LeafExpression with NonSQLExpression { - private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType) + private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType, nullable) // Interpreted execution of `LambdaVariable` always get the 0-index element from input row. override def eval(input: InternalRow): Any = { assert(input.numFields == 1, "The input row of interpreted LambdaVariable should have only 1 field.") - if (nullable && input.isNullAt(0)) { - null - } else { - accessor(input, 0) - } + accessor(input, 0) } override def genCode(ctx: CodegenContext): ExprCode = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala index 773ff5a7a4013..92517aac053b2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/json/JacksonParser.scala @@ -367,6 +367,8 @@ class JacksonParser( values += fieldConverter.apply(parser) } + // The JSON map will never have null or duplicated map keys, it's safe to create a + // ArrayBasedMapData directly here. ArrayBasedMapData(keys.toArray, values.toArray) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala new file mode 100644 index 0000000000000..e7cd61655dc9a --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala @@ -0,0 +1,120 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.types._ + +/** + * A builder of [[ArrayBasedMapData]], which fails if a null map key is detected, and removes + * duplicated map keys w.r.t. the last wins policy. + */ +class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Serializable { + assert(!keyType.existsRecursively(_.isInstanceOf[MapType]), "key of map cannot be/contain map") + assert(keyType != NullType, "map key cannot be null type.") + + private lazy val keyToIndex = keyType match { + // Binary type data is `byte[]`, which can't use `==` to check equality. 
+ case _: AtomicType | _: CalendarIntervalType if !keyType.isInstanceOf[BinaryType] => + new java.util.HashMap[Any, Int]() + case _ => + // for complex types, use interpreted ordering to be able to compare unsafe data with safe + // data, e.g. UnsafeRow vs GenericInternalRow. + new java.util.TreeMap[Any, Int](TypeUtils.getInterpretedOrdering(keyType)) + } + + // TODO: specialize it + private lazy val keys = mutable.ArrayBuffer.empty[Any] + private lazy val values = mutable.ArrayBuffer.empty[Any] + + private lazy val keyGetter = InternalRow.getAccessor(keyType) + private lazy val valueGetter = InternalRow.getAccessor(valueType) + + def put(key: Any, value: Any): Unit = { + if (key == null) { + throw new RuntimeException("Cannot use null as map key.") + } + + val index = keyToIndex.getOrDefault(key, -1) + if (index == -1) { + keyToIndex.put(key, values.length) + keys.append(key) + values.append(value) + } else { + // Overwrite the previous value, as the policy is last wins. + values(index) = value + } + } + + // write a 2-field row, the first field is key and the second field is value. + def put(entry: InternalRow): Unit = { + if (entry.isNullAt(0)) { + throw new RuntimeException("Cannot use null as map key.") + } + put(keyGetter(entry, 0), valueGetter(entry, 1)) + } + + def putAll(keyArray: ArrayData, valueArray: ArrayData): Unit = { + if (keyArray.numElements() != valueArray.numElements()) { + throw new RuntimeException( + "The key array and value array of MapData must have the same length.") + } + + var i = 0 + while (i < keyArray.numElements()) { + put(keyGetter(keyArray, i), valueGetter(valueArray, i)) + i += 1 + } + } + + private def reset(): Unit = { + keyToIndex.clear() + keys.clear() + values.clear() + } + + /** + * Builds the result [[ArrayBasedMapData]] and resets this builder to free up the resources. The + * builder becomes fresh afterward and is ready to take input and build another map. + */ + def build(): ArrayBasedMapData = { + val map = new ArrayBasedMapData( + new GenericArrayData(keys.toArray), new GenericArrayData(values.toArray)) + reset() + map + } + + /** + * Builds a [[ArrayBasedMapData]] from the given key and value array and resets this builder. The + * builder becomes fresh afterward and is ready to take input and build another map. + */ + def from(keyArray: ArrayData, valueArray: ArrayData): ArrayBasedMapData = { + assert(keyToIndex.isEmpty, "'from' can only be called with a fresh ArrayBasedMapBuilder.") + putAll(keyArray, valueArray) + if (keyToIndex.size == keyArray.numElements()) { + // If there are no duplicated map keys, create the MapData with the input key and value array, + // as they might already be in unsafe format and are more efficient. + reset() + new ArrayBasedMapData(keyArray, valueArray) + } else { + build() + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala index 91b3139443696..0989af26b8c12 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapData.scala @@ -19,6 +19,12 @@ package org.apache.spark.sql.catalyst.util import java.util.{Map => JavaMap} +/** + * A simple `MapData` implementation which is backed by 2 arrays. + * + * Note that, user is responsible to guarantee that the key array does not have duplicated + * elements, otherwise the behavior is undefined.
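As the `keyToIndex` comment in `ArrayBasedMapBuilder` above points out, binary keys (`byte[]`) cannot be deduplicated with `==`/`hashCode`, which is why `BinaryType` is excluded from the `HashMap` branch and handled by the `TreeMap` with an interpreted ordering. A short plain-Scala illustration of the underlying issue, independent of Spark:

```scala
// Arrays use reference equality and identity hash codes on the JVM,
// so a plain HashMap treats equal byte contents as two distinct keys.
val k1 = Array[Byte](1, 2)
val k2 = Array[Byte](1, 2)

println(k1 == k2)            // false: different references
println(k1.sameElements(k2)) // true: same contents

val m = new java.util.HashMap[Any, Int]()
m.put(k1, 1)
m.put(k2, 2)
println(m.size())            // 2 -- the "duplicate" key survives, hence the ordering-based map
```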
+ */ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) extends MapData { require(keyArray.numElements() == valueArray.numElements()) @@ -83,6 +89,9 @@ object ArrayBasedMapData { * Creates a [[ArrayBasedMapData]] by applying the given converters over * each (key -> value) pair from the given iterator * + * Note that, user is responsible to guarantee that the key array does not have duplicated + * elements, otherwise the behavior is undefined. + * * @param iterator Input iterator * @param size Number of elements * @param keyConverter This function is applied over all the keys extracted from the @@ -108,6 +117,12 @@ object ArrayBasedMapData { ArrayBasedMapData(keys, values) } + /** + * Creates a [[ArrayBasedMapData]] from a key and value array. + * + * Note that, user is responsible to guarantee that the key array does not have duplicated + * elements, otherwise the behavior is undefined. + */ def apply(keys: Array[_], values: Array[_]): ArrayBasedMapData = { new ArrayBasedMapData(new GenericArrayData(keys), new GenericArrayData(values)) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala index 4da8ce05fe8a3..ebbf241088f80 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayData.scala @@ -172,11 +172,7 @@ abstract class ArrayData extends SpecializedGetters with Serializable { val values = new Array[T](size) var i = 0 while (i < size) { - if (isNullAt(i)) { - values(i) = null.asInstanceOf[T] - } else { - values(i) = accessor(this, i).asInstanceOf[T] - } + values(i) = accessor(this, i).asInstanceOf[T] i += 1 } values @@ -187,11 +183,7 @@ abstract class ArrayData extends SpecializedGetters with Serializable { val accessor = InternalRow.getAccessor(elementType) var i = 0 while (i < size) { - if (isNullAt(i)) { - f(i, null) - } else { - f(i, accessor(this, i)) - } + f(i, accessor(this, i)) i += 1 } } @@ -208,11 +200,7 @@ class ArrayDataIndexedSeq[T](arrayData: ArrayData, dataType: DataType) extends I override def apply(idx: Int): T = if (0 <= idx && idx < arrayData.numElements()) { - if (arrayData.isNullAt(idx)) { - null.asInstanceOf[T] - } else { - accessor(arrayData, idx).asInstanceOf[T] - } + accessor(arrayData, idx).asInstanceOf[T] } else { throw new IndexOutOfBoundsException( s"Index $idx must be between 0 and the length of the ArrayData.") diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala index d2edb2f24688d..bed8547dbc83d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala @@ -114,13 +114,13 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val m1 = Literal.create(create_map("c" -> "3", "a" -> "4"), MapType(StringType, StringType, valueContainsNull = false)) val m2 = Literal.create(create_map("d" -> "4", "e" -> "5"), MapType(StringType, StringType)) - val m3 = Literal.create(create_map("a" -> "1", "b" -> "2"), MapType(StringType, StringType)) + val m3 = Literal.create(create_map("f" -> "1", "g" -> "2"), MapType(StringType, StringType)) val m4 = 
Literal.create(create_map("a" -> null, "c" -> "3"), MapType(StringType, StringType)) val m5 = Literal.create(create_map("a" -> 1, "b" -> 2), MapType(StringType, IntegerType)) - val m6 = Literal.create(create_map("a" -> null, "c" -> 3), MapType(StringType, IntegerType)) + val m6 = Literal.create(create_map("c" -> null, "d" -> 3), MapType(StringType, IntegerType)) val m7 = Literal.create(create_map(List(1, 2) -> 1, List(3, 4) -> 2), MapType(ArrayType(IntegerType), IntegerType)) - val m8 = Literal.create(create_map(List(5, 6) -> 3, List(1, 2) -> 4), + val m8 = Literal.create(create_map(List(5, 6) -> 3, List(7, 8) -> 4), MapType(ArrayType(IntegerType), IntegerType)) val m9 = Literal.create(create_map(1 -> "1", 2 -> "2"), MapType(IntegerType, StringType, valueContainsNull = false)) @@ -134,57 +134,33 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper MapType(IntegerType, IntegerType, valueContainsNull = true)) val mNull = Literal.create(null, MapType(StringType, StringType)) - // overlapping maps - checkEvaluation(MapConcat(Seq(m0, m1)), - ( - Array("a", "b", "c", "a"), // keys - Array("1", "2", "3", "4") // values - ) - ) + // overlapping maps should remove duplicated map keys w.r.t. last win policy. + checkEvaluation(MapConcat(Seq(m0, m1)), create_map("a" -> "4", "b" -> "2", "c" -> "3")) // maps with no overlap checkEvaluation(MapConcat(Seq(m0, m2)), create_map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5")) // 3 maps - checkEvaluation(MapConcat(Seq(m0, m1, m2)), - ( - Array("a", "b", "c", "a", "d", "e"), // keys - Array("1", "2", "3", "4", "4", "5") // values - ) - ) + checkEvaluation(MapConcat(Seq(m0, m2, m3)), + create_map("a" -> "1", "b" -> "2", "d" -> "4", "e" -> "5", "f" -> "1", "g" -> "2")) // null reference values - checkEvaluation(MapConcat(Seq(m3, m4)), - ( - Array("a", "b", "a", "c"), // keys - Array("1", "2", null, "3") // values - ) - ) + checkEvaluation(MapConcat(Seq(m2, m4)), + create_map("d" -> "4", "e" -> "5", "a" -> null, "c" -> "3")) // null primitive values checkEvaluation(MapConcat(Seq(m5, m6)), - ( - Array("a", "b", "a", "c"), // keys - Array(1, 2, null, 3) // values - ) - ) + create_map("a" -> 1, "b" -> 2, "c" -> null, "d" -> 3)) // keys that are primitive checkEvaluation(MapConcat(Seq(m9, m10)), - ( - Array(1, 2, 3, 4), // keys - Array("1", "2", "3", "4") // values - ) - ) + create_map(1 -> "1", 2 -> "2", 3 -> "3", 4 -> "4")) - // keys that are arrays, with overlap + // keys that are arrays checkEvaluation(MapConcat(Seq(m7, m8)), - ( - Array(List(1, 2), List(3, 4), List(5, 6), List(1, 2)), // keys - Array(1, 2, 3, 4) // values - ) - ) + create_map(List(1, 2) -> 1, List(3, 4) -> 2, List(5, 6) -> 3, List(7, 8) -> 4)) + // both keys and value are primitive and valueContainsNull = false checkEvaluation(MapConcat(Seq(m11, m12)), create_map(1 -> 2, 3 -> 4, 5 -> 6)) @@ -205,15 +181,14 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapConcat(Seq.empty), Map.empty) // force split expressions for input in generated code - val expectedKeys = Array.fill(65)(Seq("a", "b")).flatten ++ Array("d", "e") - val expectedValues = Array.fill(65)(Seq("1", "2")).flatten ++ Array("4", "5") - checkEvaluation(MapConcat( - Seq( - m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, - m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, - m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m0, m2 - )), - 
(expectedKeys, expectedValues)) + val expectedKeys = (1 to 65).map(_.toString) + val expectedValues = (1 to 65).map(_.toString) + checkEvaluation( + MapConcat( + expectedKeys.zip(expectedValues).map { + case (k, v) => Literal.create(create_map(k -> v), MapType(StringType, StringType)) + }), + create_map(expectedKeys.zip(expectedValues): _*)) // argument checking assert(MapConcat(Seq(m0, m1)).checkInputDataTypes().isSuccess) @@ -248,7 +223,7 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper ArrayType(IntegerType, containsNull = true), ArrayType(StringType, containsNull = true), valueContainsNull = true)) - checkEvaluation(mapConcat, Map( + checkEvaluation(mapConcat, create_map( Seq(1, 2) -> Seq("a", "b"), Seq(3, 4, null) -> Seq("c", "d", null), Seq(6) -> null)) @@ -282,7 +257,9 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper val ai1 = Literal.create(Seq(row(1, null), row(2, 20), row(3, null)), aiType) val ai2 = Literal.create(Seq.empty, aiType) val ai3 = Literal.create(null, aiType) + // The map key is duplicated val ai4 = Literal.create(Seq(row(1, 10), row(1, 20)), aiType) + // The map key is null val ai5 = Literal.create(Seq(row(1, 10), row(null, 20)), aiType) val ai6 = Literal.create(Seq(null, row(2, 20), null), aiType) @@ -290,10 +267,12 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapFromEntries(ai1), create_map(1 -> null, 2 -> 20, 3 -> null)) checkEvaluation(MapFromEntries(ai2), Map.empty) checkEvaluation(MapFromEntries(ai3), null) - checkEvaluation(MapKeys(MapFromEntries(ai4)), Seq(1, 1)) + // Duplicated map keys will be removed w.r.t. the last wins policy. + checkEvaluation(MapFromEntries(ai4), create_map(1 -> 20)) + // Map key can't be null checkExceptionInExpression[RuntimeException]( MapFromEntries(ai5), - "The first field from a struct (key) can't be null.") + "Cannot use null as map key") checkEvaluation(MapFromEntries(ai6), null) // Non-primitive-type keys and values @@ -310,13 +289,13 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation(MapFromEntries(as1), create_map("a" -> null, "b" -> "bb", "c" -> null)) checkEvaluation(MapFromEntries(as2), Map.empty) checkEvaluation(MapFromEntries(as3), null) - checkEvaluation(MapKeys(MapFromEntries(as4)), Seq("a", "a")) - checkEvaluation(MapFromEntries(as6), null) - + // Duplicated map keys will be removed w.r.t. the last wins policy. + checkEvaluation(MapFromEntries(as4), create_map("a" -> "bb")) // Map key can't be null checkExceptionInExpression[RuntimeException]( MapFromEntries(as5), - "The first field from a struct (key) can't be null.") + "Cannot use null as map key") + checkEvaluation(MapFromEntries(as6), null) // map key can't be map val structOfMap = row(create_map(1 -> 1), 1) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala index d95f42e04e37c..dc60464815043 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala @@ -183,6 +183,11 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))), "Cannot use null as map key") + // Duplicated map keys will be removed w.r.t. the last wins policy. 
+ checkEvaluation( + CreateMap(Seq(Literal(1), Literal(2), Literal(1), Literal(3))), + create_map(1 -> 3)) + // ArrayType map key and value val map = CreateMap(Seq( Literal.create(intSeq, ArrayType(IntegerType, containsNull = false)), @@ -243,12 +248,18 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { MapFromArrays(intWithNullArray, strArray), "Cannot use null as map key") + // Duplicated map keys will be removed w.r.t. the last wins policy. + checkEvaluation( + MapFromArrays( + Literal.create(Seq(1, 1), ArrayType(IntegerType)), + Literal.create(Seq(2, 3), ArrayType(IntegerType))), + create_map(1 -> 3)) + // map key can't be map val arrayOfMap = Seq(create_map(1 -> "a", 2 -> "b")) val map = MapFromArrays( Literal.create(arrayOfMap, ArrayType(MapType(IntegerType, StringType))), - Literal.create(Seq(1), ArrayType(IntegerType)) - ) + Literal.create(Seq(1), ArrayType(IntegerType))) map.checkInputDataTypes() match { case TypeCheckResult.TypeCheckSuccess => fail("should not allow map as map key") case TypeCheckResult.TypeCheckFailure(msg) => @@ -356,6 +367,11 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper { val m5 = Map("a" -> null) checkEvaluation(new StringToMap(s5), m5) + // Duplicated map keys will be removed w.r.t. the last wins policy. + checkEvaluation( + new StringToMap(Literal("a:1,b:2,a:3")), + create_map("a" -> "3", "b" -> "2")) + // arguments checking assert(new StringToMap(Literal("a:1,b:2,c:3")).checkInputDataTypes().isSuccess) assert(new StringToMap(Literal(null)).checkInputDataTypes().isFailure) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala index 66bf18af95799..03fb75e330c66 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala @@ -330,8 +330,8 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( transformKeys(transformKeys(ai0, plusOne), plusValue), create_map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4)) - checkEvaluation(transformKeys(ai0, modKey), - ArrayBasedMapData(Array(1, 2, 0, 1), Array(1, 2, 3, 4))) + // Duplicated map keys will be removed w.r.t. the last wins policy. 
+ checkEvaluation(transformKeys(ai0, modKey), create_map(1 -> 4, 2 -> 2, 0 -> 3)) checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int]) checkEvaluation( @@ -467,16 +467,13 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper .bind(validateBinding) } - val mii0 = Literal.create(Map(1 -> 10, 2 -> 20, 3 -> 30), + val mii0 = Literal.create(create_map(1 -> 10, 2 -> 20, 3 -> 30), MapType(IntegerType, IntegerType, valueContainsNull = false)) - val mii1 = Literal.create(Map(1 -> -1, 2 -> -2, 4 -> -4), + val mii1 = Literal.create(create_map(1 -> -1, 2 -> -2, 4 -> -4), MapType(IntegerType, IntegerType, valueContainsNull = false)) - val mii2 = Literal.create(Map(1 -> null, 2 -> -2, 3 -> null), + val mii2 = Literal.create(create_map(1 -> null, 2 -> -2, 3 -> null), MapType(IntegerType, IntegerType, valueContainsNull = true)) val mii3 = Literal.create(Map(), MapType(IntegerType, IntegerType, valueContainsNull = false)) - val mii4 = MapFromArrays( - Literal.create(Seq(2, 2), ArrayType(IntegerType, false)), - Literal.create(Seq(20, 200), ArrayType(IntegerType, false))) val miin = Literal.create(null, MapType(IntegerType, IntegerType, valueContainsNull = false)) val multiplyKeyWithValues: (Expression, Expression, Expression) => Expression = { @@ -492,12 +489,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( map_zip_with(mii0, mii3, multiplyKeyWithValues), Map(1 -> null, 2 -> null, 3 -> null)) - checkEvaluation( - map_zip_with(mii0, mii4, multiplyKeyWithValues), - Map(1 -> null, 2 -> 800, 3 -> null)) - checkEvaluation( - map_zip_with(mii4, mii0, multiplyKeyWithValues), - Map(2 -> 800, 1 -> null, 3 -> null)) checkEvaluation( map_zip_with(mii0, miin, multiplyKeyWithValues), null) @@ -511,9 +502,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val mss2 = Literal.create(Map("c" -> null, "b" -> "t", "a" -> null), MapType(StringType, StringType, valueContainsNull = true)) val mss3 = Literal.create(Map(), MapType(StringType, StringType, valueContainsNull = false)) - val mss4 = MapFromArrays( - Literal.create(Seq("a", "a"), ArrayType(StringType, false)), - Literal.create(Seq("a", "n"), ArrayType(StringType, false))) val mssn = Literal.create(null, MapType(StringType, StringType, valueContainsNull = false)) val concat: (Expression, Expression, Expression) => Expression = { @@ -529,12 +517,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( map_zip_with(mss0, mss3, concat), Map("a" -> null, "b" -> null, "d" -> null)) - checkEvaluation( - map_zip_with(mss0, mss4, concat), - Map("a" -> "axa", "b" -> null, "d" -> null)) - checkEvaluation( - map_zip_with(mss4, mss0, concat), - Map("a" -> "aax", "b" -> null, "d" -> null)) checkEvaluation( map_zip_with(mss0, mssn, concat), null) @@ -550,9 +532,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper val mbb2 = Literal.create(Map(b(1, 3) -> null, b(1, 2) -> b(2), b(2, 1) -> null), MapType(BinaryType, BinaryType, valueContainsNull = true)) val mbb3 = Literal.create(Map(), MapType(BinaryType, BinaryType, valueContainsNull = false)) - val mbb4 = MapFromArrays( - Literal.create(Seq(b(2, 1), b(2, 1)), ArrayType(BinaryType, false)), - Literal.create(Seq(b(1), b(9)), ArrayType(BinaryType, false))) val mbbn = Literal.create(null, MapType(BinaryType, BinaryType, valueContainsNull = false)) checkEvaluation( @@ 
-564,12 +543,6 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper checkEvaluation( map_zip_with(mbb0, mbb3, concat), Map(b(1, 2) -> null, b(2, 1) -> null, b(1, 3) -> null)) - checkEvaluation( - map_zip_with(mbb0, mbb4, concat), - Map(b(1, 2) -> null, b(2, 1) -> b(2, 1, 5, 1), b(1, 3) -> null)) - checkEvaluation( - map_zip_with(mbb4, mbb0, concat), - Map(b(2, 1) -> b(2, 1, 1, 5), b(1, 2) -> null, b(1, 3) -> null)) checkEvaluation( map_zip_with(mbb0, mbbn, concat), null) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala new file mode 100644 index 0000000000000..8509bce177129 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala @@ -0,0 +1,105 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow} +import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType, StructType} +import org.apache.spark.unsafe.Platform + +class ArrayBasedMapBuilderSuite extends SparkFunSuite { + + test("basic") { + val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType) + builder.put(1, 1) + builder.put(InternalRow(2, 2)) + builder.putAll(new GenericArrayData(Seq(3)), new GenericArrayData(Seq(3))) + val map = builder.build() + assert(map.numElements() == 3) + assert(ArrayBasedMapData.toScalaMap(map) == Map(1 -> 1, 2 -> 2, 3 -> 3)) + } + + test("fail with null key") { + val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType) + builder.put(1, null) // null value is OK + val e = intercept[RuntimeException](builder.put(null, 1)) + assert(e.getMessage.contains("Cannot use null as map key")) + } + + test("remove duplicated keys with last wins policy") { + val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType) + builder.put(1, 1) + builder.put(2, 2) + builder.put(1, 2) + val map = builder.build() + assert(map.numElements() == 2) + assert(ArrayBasedMapData.toScalaMap(map) == Map(1 -> 2, 2 -> 2)) + } + + test("binary type key") { + val builder = new ArrayBasedMapBuilder(BinaryType, IntegerType) + builder.put(Array(1.toByte), 1) + builder.put(Array(2.toByte), 2) + builder.put(Array(1.toByte), 3) + val map = builder.build() + assert(map.numElements() == 2) + val entries = ArrayBasedMapData.toScalaMap(map).iterator.toSeq + assert(entries(0)._1.asInstanceOf[Array[Byte]].toSeq == Seq(1)) + assert(entries(0)._2 == 3) + assert(entries(1)._1.asInstanceOf[Array[Byte]].toSeq == Seq(2)) + 
assert(entries(1)._2 == 2) + } + + test("struct type key") { + val builder = new ArrayBasedMapBuilder(new StructType().add("i", "int"), IntegerType) + builder.put(InternalRow(1), 1) + builder.put(InternalRow(2), 2) + val unsafeRow = { + val row = new UnsafeRow(1) + val bytes = new Array[Byte](16) + row.pointTo(bytes, 16) + row.setInt(0, 1) + row + } + builder.put(unsafeRow, 3) + val map = builder.build() + assert(map.numElements() == 2) + assert(ArrayBasedMapData.toScalaMap(map) == Map(InternalRow(1) -> 3, InternalRow(2) -> 2)) + } + + test("array type key") { + val builder = new ArrayBasedMapBuilder(ArrayType(IntegerType), IntegerType) + builder.put(new GenericArrayData(Seq(1, 1)), 1) + builder.put(new GenericArrayData(Seq(2, 2)), 2) + val unsafeArray = { + val array = new UnsafeArrayData() + val bytes = new Array[Byte](24) + Platform.putLong(bytes, Platform.BYTE_ARRAY_OFFSET, 2) + array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET, 24) + array.setInt(0, 1) + array.setInt(1, 1) + array + } + builder.put(unsafeArray, 3) + val map = builder.build() + assert(map.numElements() == 2) + assert(ArrayBasedMapData.toScalaMap(map) == + Map(new GenericArrayData(Seq(1, 1)) -> 3, new GenericArrayData(Seq(2, 2)) -> 2)) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala index 4ecc54bd2fd96..ee16b3ab07f5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcDeserializer.scala @@ -179,6 +179,8 @@ class OrcDeserializer( i += 1 } + // The ORC map will never have null or duplicated map keys, it's safe to create a + // ArrayBasedMapData directly here. updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray)) case udt: UserDefinedType[_] => newWriter(udt.sqlType, updater) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index 1199725941842..004a96d134132 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -558,8 +558,12 @@ private[parquet] class ParquetRowConverter( override def getConverter(fieldIndex: Int): Converter = keyValueConverter - override def end(): Unit = + override def end(): Unit = { + // The parquet map may contain null or duplicated map keys. When it happens, the behavior is + // undefined. + // TODO (SPARK-26174): disallow it with a config. updater.set(ArrayBasedMapData(currentKeys.toArray, currentValues.toArray)) + } // NOTE: We can't reuse the mutable Map here and must instantiate a new `Map` for the next // value.
`Row.copy()` only copies row cells, it doesn't do deep copy to objects stored in row diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index 666ba35d7a8f3..e6d1a038a5918 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -89,13 +89,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val msg1 = intercept[Exception] { df5.select(map_from_arrays($"k", $"v")).collect }.getMessage - assert(msg1.contains("Cannot use null as map key!")) + assert(msg1.contains("Cannot use null as map key")) val df6 = Seq((Seq(1, 2), Seq("a"))).toDF("k", "v") val msg2 = intercept[Exception] { df6.select(map_from_arrays($"k", $"v")).collect }.getMessage - assert(msg2.contains("The given two arrays should have the same length")) + assert(msg2.contains("The key array and value array of MapData must have the same length")) } test("struct with column name") { @@ -2588,7 +2588,7 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext { val ex3 = intercept[Exception] { dfExample1.selectExpr("transform_keys(i, (k, v) -> v)").show() } - assert(ex3.getMessage.contains("Cannot use null as map key!")) + assert(ex3.getMessage.contains("Cannot use null as map key")) val ex4 = intercept[AnalysisException] { dfExample2.selectExpr("transform_keys(j, (k, v) -> k + 1)")
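Taken together, the changes above route every map-producing path through the new `ArrayBasedMapBuilder`. The following sketch summarizes its contract using only calls that appear in this diff; it relies on Catalyst-internal classes, so it is illustrative rather than public API:

```scala
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapBuilder, ArrayBasedMapData}
import org.apache.spark.sql.types.IntegerType

val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
builder.put(1, 10)
builder.put(2, 20)
builder.put(1, 30) // duplicated key: the last value wins and overwrites 10

val map = builder.build() // also resets the builder so it can be reused
assert(ArrayBasedMapData.toScalaMap(map) == Map(1 -> 30, 2 -> 20))

// Null keys are rejected eagerly instead of producing an undefined map.
try builder.put(null, 1) catch {
  case e: RuntimeException => println(e.getMessage) // Cannot use null as map key.
}
```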