diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md
index 5a5e802f6a900..e94bf8648b605 100644
--- a/docs/sql-migration-guide.md
+++ b/docs/sql-migration-guide.md
@@ -49,7 +49,7 @@ license: |
   - In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but -0.0 and 0.0 are considered as different values when used in aggregate grouping keys, window partition keys and join keys. Since Spark 3.0, this bug is fixed. For example, `Seq(-0.0, 0.0).toDF("d").groupBy("d").count()` returns `[(0.0, 2)]` in Spark 3.0, and `[(0.0, 1), (-0.0, 1)]` in Spark 2.4 and earlier.

-  - In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of map with duplicated keys is undefined, e.g. map look up respects the duplicated key appears first, `Dataset.collect` only keeps the duplicated key appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with last wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet), the behavior will be undefined.
+  - In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of a map with duplicated keys is undefined, e.g. map lookup respects the duplicated key that appears first, `Dataset.collect` only keeps the duplicated key that appears last, and `MapKeys` returns duplicated keys. Since Spark 3.0, a new config `spark.sql.deduplicateMapKey.lastWinsPolicy.enabled` was added, with the default value `false`. With the default, Spark throws a RuntimeException when duplicated keys are found. If set to `true`, these built-in functions remove duplicated map keys with the last wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet); in that case the behavior is undefined.

   - In Spark version 2.4 and earlier, partition column value is converted as null if it can't be casted to corresponding user provided schema. Since 3.0, partition column value is validated with user provided schema. An exception is thrown if the validation fails. You can disable such validation by setting `spark.sql.sources.validatePartitionColumns` to `false`.
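To make the new default concrete, here is a rough spark-shell sketch of both modes (illustrative only; the exact exception message is abbreviated):

```scala
// Default (spark.sql.deduplicateMapKey.lastWinsPolicy.enabled=false):
// building a map with a duplicated key now fails at runtime.
spark.sql("SELECT map(1, 'a', 1, 'b')").collect()
// java.lang.RuntimeException: Duplicate map key 1 was found, please set
// spark.sql.deduplicateMapKey.lastWinsPolicy.enabled to true ...

// Opting in restores the Spark 2.4 last-wins behavior.
spark.conf.set("spark.sql.deduplicateMapKey.lastWinsPolicy.enabled", "true")
spark.sql("SELECT map(1, 'a', 1, 'b')").collect()
// Array([Map(1 -> b)])
```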
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index e80d556cc89e3..4c55011f21a56 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -2766,6 +2766,7 @@ def map_concat(*cols):
     :param cols: list of column names (string) or list of :class:`Column` expressions

     >>> from pyspark.sql.functions import map_concat
+    >>> spark.conf.set("spark.sql.deduplicateMapKey.lastWinsPolicy.enabled", "true")
     >>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2")
     >>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False)
     +------------------------+
@@ -2773,6 +2774,7 @@ def map_concat(*cols):
     +------------------------+
     |[1 -> d, 2 -> b, 3 -> c]|
     +------------------------+
+    >>> spark.conf.unset("spark.sql.deduplicateMapKey.lastWinsPolicy.enabled")
     """
     sc = SparkContext._active_spark_context
     if len(cols) == 1 and isinstance(cols[0], (list, set)):
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
index 98934368205ec..3b0f78ddf33ba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala
@@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.util
 import scala.collection.mutable

 import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.internal.SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.array.ByteArrayMethods

@@ -63,6 +65,11 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
       keys.append(key)
       values.append(value)
     } else {
+      if (!SQLConf.get.getConf(DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY)) {
+        throw new RuntimeException(s"Duplicate map key $key was found, please set " +
+          s"${DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key} to true to remove it with " +
+          "the last wins policy.")
+      }
       // Overwrite the previous value, as the policy is last wins.
       values(index) = value
     }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index b0be37d2b2ee5..2b4cebbaef72b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -2167,6 +2167,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)

+  val DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY =
+    buildConf("spark.sql.deduplicateMapKey.lastWinsPolicy.enabled")
+      .doc("When true, use the last wins policy to remove duplicated map keys in built-in " +
+        "functions. This config takes effect in the following built-in functions: CreateMap, " +
+        "MapFromArrays, MapFromEntries, StringToMap, MapConcat and TransformKeys.")
+      .booleanConf
+      .createWithDefault(false)
+
   /**
    * Holds information about keys that have been deprecated.
    *
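For intuition, the change to `ArrayBasedMapBuilder.put` boils down to the following standalone model. This is a simplified sketch, not the real builder (which also normalizes binary, struct, and array keys and reads the flag from the thread-local `SQLConf`); the class name `MapBuilderSketch` is made up for illustration.

```scala
import scala.collection.mutable

class MapBuilderSketch[K, V](lastWinsEnabled: Boolean) {
  private val keyToIndex = mutable.HashMap.empty[K, Int]
  private val keys = mutable.ArrayBuffer.empty[K]
  private val values = mutable.ArrayBuffer.empty[V]

  def put(key: K, value: V): Unit = keyToIndex.get(key) match {
    case None =>
      // First time we see this key: record its position and append.
      keyToIndex(key) = values.length
      keys.append(key)
      values.append(value)
    case Some(index) =>
      if (!lastWinsEnabled) {
        // Default behavior after this change: fail fast on duplicates.
        throw new RuntimeException(s"Duplicate map key $key was found")
      }
      // Overwrite the previous value in place, as the policy is last wins.
      values(index) = value
  }

  def build(): Seq[(K, V)] = keys.toSeq.zip(values)
}
```

Tracking the index of the first occurrence means last wins replaces only the value, so key order is preserved; this is why the tests below expect `create_map("a" -> "4", "b" -> "2", "c" -> "3")` from `MapConcat` rather than `a` moving to the end.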
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
index 9e98e146c7a0e..99b3ffcbdcbd1 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CollectionExpressionsSuite.scala
@@ -139,8 +139,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
       MapType(IntegerType, IntegerType, valueContainsNull = true))
     val mNull = Literal.create(null, MapType(StringType, StringType))

-    // overlapping maps should remove duplicated map keys w.r.t. last win policy.
-    checkEvaluation(MapConcat(Seq(m0, m1)), create_map("a" -> "4", "b" -> "2", "c" -> "3"))
+    withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
+      // overlapping maps should remove duplicated map keys w.r.t. last win policy.
+      checkEvaluation(MapConcat(Seq(m0, m1)), create_map("a" -> "4", "b" -> "2", "c" -> "3"))
+    }

     // maps with no overlap
     checkEvaluation(MapConcat(Seq(m0, m2)),
@@ -272,8 +274,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(MapFromEntries(ai1), create_map(1 -> null, 2 -> 20, 3 -> null))
     checkEvaluation(MapFromEntries(ai2), Map.empty)
     checkEvaluation(MapFromEntries(ai3), null)
-    // Duplicated map keys will be removed w.r.t. the last wins policy.
-    checkEvaluation(MapFromEntries(ai4), create_map(1 -> 20))
+    withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
+      // Duplicated map keys will be removed w.r.t. the last wins policy.
+      checkEvaluation(MapFromEntries(ai4), create_map(1 -> 20))
+    }
     // Map key can't be null
     checkExceptionInExpression[RuntimeException](
       MapFromEntries(ai5),
@@ -294,8 +298,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(MapFromEntries(as1), create_map("a" -> null, "b" -> "bb", "c" -> null))
     checkEvaluation(MapFromEntries(as2), Map.empty)
     checkEvaluation(MapFromEntries(as3), null)
-    // Duplicated map keys will be removed w.r.t. the last wins policy.
-    checkEvaluation(MapFromEntries(as4), create_map("a" -> "bb"))
+    withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
+      // Duplicated map keys will be removed w.r.t. the last wins policy.
+      checkEvaluation(MapFromEntries(as4), create_map("a" -> "bb"))
+    }
     // Map key can't be null
     checkExceptionInExpression[RuntimeException](
       MapFromEntries(as5),
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
index 9039cd6451590..f90bd4981e399 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ComplexTypeSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue}
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String

@@ -216,10 +217,12 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
       CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
       "Cannot use null as map key")

-    // Duplicated map keys will be removed w.r.t. the last wins policy.
-    checkEvaluation(
-      CreateMap(Seq(Literal(1), Literal(2), Literal(1), Literal(3))),
-      create_map(1 -> 3))
+    withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
+      // Duplicated map keys will be removed w.r.t. the last wins policy.
+      checkEvaluation(
+        CreateMap(Seq(Literal(1), Literal(2), Literal(1), Literal(3))),
+        create_map(1 -> 3))
+    }

     // ArrayType map key and value
     val map = CreateMap(Seq(
@@ -281,12 +284,14 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
       MapFromArrays(intWithNullArray, strArray),
       "Cannot use null as map key")

-    // Duplicated map keys will be removed w.r.t. the last wins policy.
-    checkEvaluation(
-      MapFromArrays(
-        Literal.create(Seq(1, 1), ArrayType(IntegerType)),
-        Literal.create(Seq(2, 3), ArrayType(IntegerType))),
-      create_map(1 -> 3))
+    withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
+      // Duplicated map keys will be removed w.r.t. the last wins policy.
+      checkEvaluation(
+        MapFromArrays(
+          Literal.create(Seq(1, 1), ArrayType(IntegerType)),
+          Literal.create(Seq(2, 3), ArrayType(IntegerType))),
+        create_map(1 -> 3))
+    }

     // map key can't be map
     val arrayOfMap = Seq(create_map(1 -> "a", 2 -> "b"))
@@ -399,10 +404,12 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
     val m5 = Map("a" -> null)
     checkEvaluation(new StringToMap(s5), m5)

-    // Duplicated map keys will be removed w.r.t. the last wins policy.
-    checkEvaluation(
-      new StringToMap(Literal("a:1,b:2,a:3")),
-      create_map("a" -> "3", "b" -> "2"))
+    withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
+      // Duplicated map keys will be removed w.r.t. the last wins policy.
+      checkEvaluation(
+        new StringToMap(Literal("a:1,b:2,a:3")),
+        create_map("a" -> "3", "b" -> "2"))
+    }

     // arguments checking
     assert(new StringToMap(Literal("a:1,b:2,c:3")).checkInputDataTypes().isSuccess)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
index e7b713840b884..c1cd58b51696e 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/HigherOrderFunctionsSuite.scala
@@ -465,8 +465,10 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
     checkEvaluation(
       transformKeys(transformKeys(ai0, plusOne), plusValue),
       create_map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4))
-    // Duplicated map keys will be removed w.r.t. the last wins policy.
-    checkEvaluation(transformKeys(ai0, modKey), create_map(1 -> 4, 2 -> 2, 0 -> 3))
+    withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
+      // Duplicated map keys will be removed w.r.t. the last wins policy.
+      checkEvaluation(transformKeys(ai0, modKey), create_map(1 -> 4, 2 -> 2, 0 -> 3))
+    }
     checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
     checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
     checkEvaluation(
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala
index 8509bce177129..f8365aa979fde 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala
@@ -20,10 +20,12 @@ package org.apache.spark.sql.catalyst.util
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.SQLHelper
+import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType, StructType}
 import org.apache.spark.unsafe.Platform

-class ArrayBasedMapBuilderSuite extends SparkFunSuite {
+class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper {

   test("basic") {
     val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
@@ -43,63 +45,71 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite {
   }

   test("remove duplicated keys with last wins policy") {
-    val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
-    builder.put(1, 1)
-    builder.put(2, 2)
-    builder.put(1, 2)
-    val map = builder.build()
-    assert(map.numElements() == 2)
-    assert(ArrayBasedMapData.toScalaMap(map) == Map(1 -> 2, 2 -> 2))
+    withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
+      val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
+      builder.put(1, 1)
+      builder.put(2, 2)
+      builder.put(1, 2)
+      val map = builder.build()
+      assert(map.numElements() == 2)
+      assert(ArrayBasedMapData.toScalaMap(map) == Map(1 -> 2, 2 -> 2))
+    }
   }

   test("binary type key") {
-    val builder = new ArrayBasedMapBuilder(BinaryType, IntegerType)
-    builder.put(Array(1.toByte), 1)
-    builder.put(Array(2.toByte), 2)
-    builder.put(Array(1.toByte), 3)
-    val map = builder.build()
-    assert(map.numElements() == 2)
-    val entries = ArrayBasedMapData.toScalaMap(map).iterator.toSeq
-    assert(entries(0)._1.asInstanceOf[Array[Byte]].toSeq == Seq(1))
-    assert(entries(0)._2 == 3)
-    assert(entries(1)._1.asInstanceOf[Array[Byte]].toSeq == Seq(2))
-    assert(entries(1)._2 == 2)
+    withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
+      val builder = new ArrayBasedMapBuilder(BinaryType, IntegerType)
+      builder.put(Array(1.toByte), 1)
+      builder.put(Array(2.toByte), 2)
+      builder.put(Array(1.toByte), 3)
+      val map = builder.build()
+      assert(map.numElements() == 2)
+      val entries = ArrayBasedMapData.toScalaMap(map).iterator.toSeq
+      assert(entries(0)._1.asInstanceOf[Array[Byte]].toSeq == Seq(1))
+      assert(entries(0)._2 == 3)
+      assert(entries(1)._1.asInstanceOf[Array[Byte]].toSeq == Seq(2))
+      assert(entries(1)._2 == 2)
+    }
   }

   test("struct type key") {
-    val builder = new ArrayBasedMapBuilder(new StructType().add("i", "int"), IntegerType)
-    builder.put(InternalRow(1), 1)
-    builder.put(InternalRow(2), 2)
-    val unsafeRow = {
-      val row = new UnsafeRow(1)
-      val bytes = new Array[Byte](16)
-      row.pointTo(bytes, 16)
-      row.setInt(0, 1)
-      row
+    withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
+      val builder = new ArrayBasedMapBuilder(new StructType().add("i", "int"), IntegerType)
+      builder.put(InternalRow(1), 1)
+      builder.put(InternalRow(2), 2)
+      val unsafeRow = {
+        val row = new UnsafeRow(1)
+        val bytes = new Array[Byte](16)
+        row.pointTo(bytes, 16)
+        row.setInt(0, 1)
+        row
+      }
+      builder.put(unsafeRow, 3)
+      val map = builder.build()
+      assert(map.numElements() == 2)
+      assert(ArrayBasedMapData.toScalaMap(map) == Map(InternalRow(1) -> 3, InternalRow(2) -> 2))
     }
-    builder.put(unsafeRow, 3)
-    val map = builder.build()
-    assert(map.numElements() == 2)
-    assert(ArrayBasedMapData.toScalaMap(map) == Map(InternalRow(1) -> 3, InternalRow(2) -> 2))
   }

   test("array type key") {
-    val builder = new ArrayBasedMapBuilder(ArrayType(IntegerType), IntegerType)
-    builder.put(new GenericArrayData(Seq(1, 1)), 1)
-    builder.put(new GenericArrayData(Seq(2, 2)), 2)
-    val unsafeArray = {
-      val array = new UnsafeArrayData()
-      val bytes = new Array[Byte](24)
-      Platform.putLong(bytes, Platform.BYTE_ARRAY_OFFSET, 2)
-      array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET, 24)
-      array.setInt(0, 1)
-      array.setInt(1, 1)
-      array
+    withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
+      val builder = new ArrayBasedMapBuilder(ArrayType(IntegerType), IntegerType)
+      builder.put(new GenericArrayData(Seq(1, 1)), 1)
+      builder.put(new GenericArrayData(Seq(2, 2)), 2)
+      val unsafeArray = {
+        val array = new UnsafeArrayData()
+        val bytes = new Array[Byte](24)
+        Platform.putLong(bytes, Platform.BYTE_ARRAY_OFFSET, 2)
+        array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET, 24)
+        array.setInt(0, 1)
+        array.setInt(1, 1)
+        array
+      }
+      builder.put(unsafeArray, 3)
+      val map = builder.build()
+      assert(map.numElements() == 2)
+      assert(ArrayBasedMapData.toScalaMap(map) ==
+        Map(new GenericArrayData(Seq(1, 1)) -> 3, new GenericArrayData(Seq(2, 2)) -> 2))
     }
-    builder.put(unsafeArray, 3)
-    val map = builder.build()
-    assert(map.numElements() == 2)
-    assert(ArrayBasedMapData.toScalaMap(map) ==
-      Map(new GenericArrayData(Seq(1, 1)) -> 3, new GenericArrayData(Seq(2, 2)) -> 2))
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 7fce03658fc16..094ff08a8270d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -651,8 +651,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
       Row(null)
     )

-    checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a)
-    checkAnswer(df1.select(map_concat($"map1", $"map2")), expected1a)
+    withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
+      checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a)
+      checkAnswer(df1.select(map_concat($"map1", $"map2")), expected1a)
+    }

     val expected1b = Seq(
       Row(Map(1 -> 100, 2 -> 200)),
@@ -3068,11 +3070,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
     checkAnswer(dfExample2.select(transform_keys(col("j"), (k, v) => k + v)),
       Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))))

-    checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"),
-      Seq(Row(Map(true -> true, true -> false))))
+    withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
+      checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"),
+        Seq(Row(Map(true -> true, true -> false))))

-    checkAnswer(dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)),
-      Seq(Row(Map(true -> true, true -> false))))
+      checkAnswer(dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)),
+        Seq(Row(Map(true -> true, true -> false))))
+    }

     checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"),
       Seq(Row(Map(50 -> true, 78 -> false))))
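For completeness, the same opt-in from the Scala DataFrame API, mirroring the PySpark doctest above (a sketch assuming a `spark` session, e.g. in spark-shell):

```scala
import org.apache.spark.sql.functions.map_concat
import spark.implicits._

spark.conf.set("spark.sql.deduplicateMapKey.lastWinsPolicy.enabled", "true")

// map_concat collides on key 1; under last wins, map2's value 'd' survives.
val df = spark.sql("SELECT map(1, 'a', 2, 'b') AS map1, map(3, 'c', 1, 'd') AS map2")
df.select(map_concat($"map1", $"map2").alias("map3")).show(truncate = false)
// +------------------------+
// |map3                    |
// +------------------------+
// |[1 -> d, 2 -> b, 3 -> c]|
// +------------------------+

// Restore the default so later queries fail fast on duplicated keys again.
spark.conf.unset("spark.sql.deduplicateMapKey.lastWinsPolicy.enabled")
```

The config is session-scoped, which is why the test suites wrap each affected assertion in `withSQLConf`: it sets the conf for the duration of the block and restores the previous value afterwards.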