[SPARK-25829][SQL] Add config `spark.sql.legacy.allowDuplicatedMapKeys` and change the default behavior

### What changes were proposed in this pull request?
This is a follow-up to #23124: it adds a new config, `spark.sql.legacy.allowDuplicatedMapKeys`, to control whether built-in functions remove duplicated map keys. With the default value `false`, Spark throws a RuntimeException when duplicated keys are found.
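For illustration, a minimal sketch of the resulting behavior in a Spark 3.0 `spark-shell` session; the config name and error message come from this patch, while the sample query and the printed output are illustrative assumptions:

```scala
// Default (spark.sql.legacy.allowDuplicatedMapKeys=false): building a map with a
// duplicated key now fails at runtime.
spark.sql("SELECT map(1, 'a', 1, 'b')").collect()
// java.lang.RuntimeException: Duplicate map key 1 was found, please check the input data.
// If you want to remove the duplicated keys with last-wins policy, you can set
// spark.sql.legacy.allowDuplicatedMapKeys to true.

// Opt back into the Spark 2.4 behavior: duplicated keys are removed, last wins.
spark.conf.set("spark.sql.legacy.allowDuplicatedMapKeys", "true")
spark.sql("SELECT map(1, 'a', 1, 'b')").collect()
// Array([Map(1 -> b)])  (illustrative rendering of the collected Row)
```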

### Why are the changes needed?
Prevent silent behavior changes.

### Does this PR introduce any user-facing change?
Yes, a new config is added, and the default behavior for duplicated map keys changes from silently removing them to throwing a RuntimeException.

### How was this patch tested?
Modified existing unit tests.

Closes #27478 from xuanyuanking/SPARK-25892-follow.

Authored-by: Yuanjian Li <xyliyuanjian@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit ab186e3)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
xuanyuanking authored and cloud-fan committed Feb 17, 2020
1 parent cea5cbc commit 33329ca
Showing 10 changed files with 136 additions and 80 deletions.
2 changes: 1 addition & 1 deletion docs/sql-migration-guide.md
@@ -49,7 +49,7 @@ license: |

- In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but -0.0 and 0.0 are considered as different values when used in aggregate grouping keys, window partition keys and join keys. Since Spark 3.0, this bug is fixed. For example, `Seq(-0.0, 0.0).toDF("d").groupBy("d").count()` returns `[(0.0, 2)]` in Spark 3.0, and `[(0.0, 1), (-0.0, 1)]` in Spark 2.4 and earlier.

- In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of a map with duplicated keys is undefined, e.g. map lookup respects the duplicated key that appears first, `Dataset.collect` only keeps the duplicated key that appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with the last-wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet); the behavior will be undefined.
- In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of a map with duplicated keys is undefined, e.g. map lookup respects the duplicated key that appears first, `Dataset.collect` only keeps the duplicated key that appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, the new config `spark.sql.legacy.allowDuplicatedMapKeys` is added, with the default value `false`, and Spark throws a RuntimeException when duplicated keys are found. If set to `true`, these built-in functions will remove duplicated map keys with the last-wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet); the behavior will be undefined.

- In Spark version 2.4 and earlier, partition column value is converted as null if it can't be casted to corresponding user provided schema. Since 3.0, partition column value is validated with user provided schema. An exception is thrown if the validation fails. You can disable such validation by setting `spark.sql.sources.validatePartitionColumns` to `false`.

4 changes: 2 additions & 2 deletions python/pyspark/sql/functions.py
@@ -2766,12 +2766,12 @@ def map_concat(*cols):
:param cols: list of column names (string) or list of :class:`Column` expressions
>>> from pyspark.sql.functions import map_concat
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2")
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c') as map2")
>>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False)
+------------------------+
|map3 |
+------------------------+
|[1 -> d, 2 -> b, 3 -> c]|
|[1 -> a, 2 -> b, 3 -> c]|
+------------------------+
"""
sc = SparkContext._active_spark_context
@@ -516,8 +516,8 @@ case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInp
usage = "_FUNC_(map, ...) - Returns the union of all the given maps",
examples = """
Examples:
> SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd'));
{1:"a",2:"c",3:"d"}
> SELECT _FUNC_(map(1, 'a', 2, 'b'), map(3, 'c'));
{1:"a",2:"b",3:"c"}
""", since = "2.4.0")
case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression {

@@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.util
import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.LEGACY_ALLOW_DUPLICATED_MAP_KEY
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods

@@ -47,6 +49,9 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
private lazy val keyGetter = InternalRow.getAccessor(keyType)
private lazy val valueGetter = InternalRow.getAccessor(valueType)

private val allowDuplicatedMapKey =
SQLConf.get.getConf(LEGACY_ALLOW_DUPLICATED_MAP_KEY)

def put(key: Any, value: Any): Unit = {
if (key == null) {
throw new RuntimeException("Cannot use null as map key.")
@@ -62,6 +67,11 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
keys.append(key)
values.append(value)
} else {
if (!allowDuplicatedMapKey) {
throw new RuntimeException(s"Duplicate map key $key was founded, please check the input " +
"data. If you want to remove the duplicated keys with last-win policy, you can set " +
s"${LEGACY_ALLOW_DUPLICATED_MAP_KEY.key} to true.")
}
// Overwrite the previous value, as the policy is last wins.
values(index) = value
}
@@ -2188,6 +2188,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val LEGACY_ALLOW_DUPLICATED_MAP_KEY =
buildConf("spark.sql.legacy.allowDuplicatedMapKeys")
.doc("When true, use last wins policy to remove duplicated map keys in built-in functions, " +
"this config takes effect in below build-in functions: CreateMap, MapFromArrays, " +
"MapFromEntries, StringToMap, MapConcat and TransformKeys. Otherwise, if this is false, " +
"which is the default, Spark will throw an exception when duplicated map keys are " +
"detected.")
.booleanConf
.createWithDefault(false)

/**
* Holds information about keys that have been deprecated.
*
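As a usage note (not part of the diff above), the legacy flag can also be set per session with a SQL `SET` command. A hedged sketch of two of the listed built-ins once it is enabled, reusing example values that appear in this patch's tests and docs; the printed output is an illustrative rendering:

```scala
// Illustrative only: restore the 2.4 last-wins behavior for the functions named in
// the config doc (CreateMap, MapFromArrays, MapFromEntries, StringToMap, MapConcat,
// TransformKeys).
spark.sql("SET spark.sql.legacy.allowDuplicatedMapKeys=true")

// str_to_map keeps the last value seen for the duplicated key 'a'.
spark.sql("SELECT str_to_map('a:1,b:2,a:3')").collect()
// Array([Map(a -> 3, b -> 2)])

// map_concat resolves the overlapping key 2 with the value from the last map.
spark.sql("SELECT map_concat(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd'))").collect()
// Array([Map(1 -> a, 2 -> c, 3 -> d)])
```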
@@ -139,8 +139,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
MapType(IntegerType, IntegerType, valueContainsNull = true))
val mNull = Literal.create(null, MapType(StringType, StringType))

// overlapping maps should remove duplicated map keys w.r.t. last win policy.
checkEvaluation(MapConcat(Seq(m0, m1)), create_map("a" -> "4", "b" -> "2", "c" -> "3"))
withSQLConf(SQLConf.LEGACY_ALLOW_DUPLICATED_MAP_KEY.key -> "true") {
// overlapping maps should remove duplicated map keys w.r.t. last win policy.
checkEvaluation(MapConcat(Seq(m0, m1)), create_map("a" -> "4", "b" -> "2", "c" -> "3"))
}

// maps with no overlap
checkEvaluation(MapConcat(Seq(m0, m2)),
@@ -272,8 +274,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MapFromEntries(ai1), create_map(1 -> null, 2 -> 20, 3 -> null))
checkEvaluation(MapFromEntries(ai2), Map.empty)
checkEvaluation(MapFromEntries(ai3), null)
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(MapFromEntries(ai4), create_map(1 -> 20))
withSQLConf(SQLConf.LEGACY_ALLOW_DUPLICATED_MAP_KEY.key -> "true") {
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(MapFromEntries(ai4), create_map(1 -> 20))
}
// Map key can't be null
checkExceptionInExpression[RuntimeException](
MapFromEntries(ai5),
@@ -294,8 +298,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MapFromEntries(as1), create_map("a" -> null, "b" -> "bb", "c" -> null))
checkEvaluation(MapFromEntries(as2), Map.empty)
checkEvaluation(MapFromEntries(as3), null)
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(MapFromEntries(as4), create_map("a" -> "bb"))
withSQLConf(SQLConf.LEGACY_ALLOW_DUPLICATED_MAP_KEY.key -> "true") {
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(MapFromEntries(as4), create_map("a" -> "bb"))
}
// Map key can't be null
checkExceptionInExpression[RuntimeException](
MapFromEntries(as5),
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@@ -216,10 +217,12 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
"Cannot use null as map key")

// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(
CreateMap(Seq(Literal(1), Literal(2), Literal(1), Literal(3))),
create_map(1 -> 3))
withSQLConf(SQLConf.LEGACY_ALLOW_DUPLICATED_MAP_KEY.key -> "true") {
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(
CreateMap(Seq(Literal(1), Literal(2), Literal(1), Literal(3))),
create_map(1 -> 3))
}

// ArrayType map key and value
val map = CreateMap(Seq(
@@ -281,12 +284,14 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
MapFromArrays(intWithNullArray, strArray),
"Cannot use null as map key")

// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(
MapFromArrays(
Literal.create(Seq(1, 1), ArrayType(IntegerType)),
Literal.create(Seq(2, 3), ArrayType(IntegerType))),
create_map(1 -> 3))
withSQLConf(SQLConf.LEGACY_ALLOW_DUPLICATED_MAP_KEY.key -> "true") {
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(
MapFromArrays(
Literal.create(Seq(1, 1), ArrayType(IntegerType)),
Literal.create(Seq(2, 3), ArrayType(IntegerType))),
create_map(1 -> 3))
}

// map key can't be map
val arrayOfMap = Seq(create_map(1 -> "a", 2 -> "b"))
@@ -399,10 +404,12 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
val m5 = Map("a" -> null)
checkEvaluation(new StringToMap(s5), m5)

// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(
new StringToMap(Literal("a:1,b:2,a:3")),
create_map("a" -> "3", "b" -> "2"))
withSQLConf(SQLConf.LEGACY_ALLOW_DUPLICATED_MAP_KEY.key -> "true") {
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(
new StringToMap(Literal("a:1,b:2,a:3")),
create_map("a" -> "3", "b" -> "2"))
}

// arguments checking
assert(new StringToMap(Literal("a:1,b:2,c:3")).checkInputDataTypes().isSuccess)
@@ -465,8 +465,10 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(
transformKeys(transformKeys(ai0, plusOne), plusValue),
create_map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4))
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(transformKeys(ai0, modKey), create_map(1 -> 4, 2 -> 2, 0 -> 3))
withSQLConf(SQLConf.LEGACY_ALLOW_DUPLICATED_MAP_KEY.key -> "true") {
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(transformKeys(ai0, modKey), create_map(1 -> 4, 2 -> 2, 0 -> 3))
}
checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
checkEvaluation(
@@ -20,10 +20,12 @@ package org.apache.spark.sql.catalyst.util
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType, StructType}
import org.apache.spark.unsafe.Platform

class ArrayBasedMapBuilderSuite extends SparkFunSuite {
class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper {

test("basic") {
val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
@@ -42,64 +44,79 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite {
assert(e.getMessage.contains("Cannot use null as map key"))
}

test("remove duplicated keys with last wins policy") {
test("fail while duplicated keys detected") {
val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
builder.put(1, 1)
builder.put(2, 2)
builder.put(1, 2)
val map = builder.build()
assert(map.numElements() == 2)
assert(ArrayBasedMapData.toScalaMap(map) == Map(1 -> 2, 2 -> 2))
val e = intercept[RuntimeException](builder.put(1, 2))
assert(e.getMessage.contains("Duplicate map key 1 was founded"))
}

test("remove duplicated keys with last wins policy") {
withSQLConf(SQLConf.LEGACY_ALLOW_DUPLICATED_MAP_KEY.key -> "true") {
val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
builder.put(1, 1)
builder.put(2, 2)
builder.put(1, 2)
val map = builder.build()
assert(map.numElements() == 2)
assert(ArrayBasedMapData.toScalaMap(map) == Map(1 -> 2, 2 -> 2))
}
}

test("binary type key") {
val builder = new ArrayBasedMapBuilder(BinaryType, IntegerType)
builder.put(Array(1.toByte), 1)
builder.put(Array(2.toByte), 2)
builder.put(Array(1.toByte), 3)
val map = builder.build()
assert(map.numElements() == 2)
val entries = ArrayBasedMapData.toScalaMap(map).iterator.toSeq
assert(entries(0)._1.asInstanceOf[Array[Byte]].toSeq == Seq(1))
assert(entries(0)._2 == 3)
assert(entries(1)._1.asInstanceOf[Array[Byte]].toSeq == Seq(2))
assert(entries(1)._2 == 2)
withSQLConf(SQLConf.LEGACY_ALLOW_DUPLICATED_MAP_KEY.key -> "true") {
val builder = new ArrayBasedMapBuilder(BinaryType, IntegerType)
builder.put(Array(1.toByte), 1)
builder.put(Array(2.toByte), 2)
builder.put(Array(1.toByte), 3)
val map = builder.build()
assert(map.numElements() == 2)
val entries = ArrayBasedMapData.toScalaMap(map).iterator.toSeq
assert(entries(0)._1.asInstanceOf[Array[Byte]].toSeq == Seq(1))
assert(entries(0)._2 == 3)
assert(entries(1)._1.asInstanceOf[Array[Byte]].toSeq == Seq(2))
assert(entries(1)._2 == 2)
}
}

test("struct type key") {
val builder = new ArrayBasedMapBuilder(new StructType().add("i", "int"), IntegerType)
builder.put(InternalRow(1), 1)
builder.put(InternalRow(2), 2)
val unsafeRow = {
val row = new UnsafeRow(1)
val bytes = new Array[Byte](16)
row.pointTo(bytes, 16)
row.setInt(0, 1)
row
withSQLConf(SQLConf.LEGACY_ALLOW_DUPLICATED_MAP_KEY.key -> "true") {
val builder = new ArrayBasedMapBuilder(new StructType().add("i", "int"), IntegerType)
builder.put(InternalRow(1), 1)
builder.put(InternalRow(2), 2)
val unsafeRow = {
val row = new UnsafeRow(1)
val bytes = new Array[Byte](16)
row.pointTo(bytes, 16)
row.setInt(0, 1)
row
}
builder.put(unsafeRow, 3)
val map = builder.build()
assert(map.numElements() == 2)
assert(ArrayBasedMapData.toScalaMap(map) == Map(InternalRow(1) -> 3, InternalRow(2) -> 2))
}
builder.put(unsafeRow, 3)
val map = builder.build()
assert(map.numElements() == 2)
assert(ArrayBasedMapData.toScalaMap(map) == Map(InternalRow(1) -> 3, InternalRow(2) -> 2))
}

test("array type key") {
val builder = new ArrayBasedMapBuilder(ArrayType(IntegerType), IntegerType)
builder.put(new GenericArrayData(Seq(1, 1)), 1)
builder.put(new GenericArrayData(Seq(2, 2)), 2)
val unsafeArray = {
val array = new UnsafeArrayData()
val bytes = new Array[Byte](24)
Platform.putLong(bytes, Platform.BYTE_ARRAY_OFFSET, 2)
array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET, 24)
array.setInt(0, 1)
array.setInt(1, 1)
array
withSQLConf(SQLConf.LEGACY_ALLOW_DUPLICATED_MAP_KEY.key -> "true") {
val builder = new ArrayBasedMapBuilder(ArrayType(IntegerType), IntegerType)
builder.put(new GenericArrayData(Seq(1, 1)), 1)
builder.put(new GenericArrayData(Seq(2, 2)), 2)
val unsafeArray = {
val array = new UnsafeArrayData()
val bytes = new Array[Byte](24)
Platform.putLong(bytes, Platform.BYTE_ARRAY_OFFSET, 2)
array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET, 24)
array.setInt(0, 1)
array.setInt(1, 1)
array
}
builder.put(unsafeArray, 3)
val map = builder.build()
assert(map.numElements() == 2)
assert(ArrayBasedMapData.toScalaMap(map) ==
Map(new GenericArrayData(Seq(1, 1)) -> 3, new GenericArrayData(Seq(2, 2)) -> 2))
}
builder.put(unsafeArray, 3)
val map = builder.build()
assert(map.numElements() == 2)
assert(ArrayBasedMapData.toScalaMap(map) ==
Map(new GenericArrayData(Seq(1, 1)) -> 3, new GenericArrayData(Seq(2, 2)) -> 2))
}
}
@@ -651,8 +651,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
Row(null)
)

checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a)
checkAnswer(df1.select(map_concat($"map1", $"map2")), expected1a)
withSQLConf(SQLConf.LEGACY_ALLOW_DUPLICATED_MAP_KEY.key -> "true") {
checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a)
checkAnswer(df1.select(map_concat($"map1", $"map2")), expected1a)
}

val expected1b = Seq(
Row(Map(1 -> 100, 2 -> 200)),
@@ -3068,11 +3070,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(dfExample2.select(transform_keys(col("j"), (k, v) => k + v)),
Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))))

checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"),
Seq(Row(Map(true -> true, true -> false))))
withSQLConf(SQLConf.LEGACY_ALLOW_DUPLICATED_MAP_KEY.key -> "true") {
checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"),
Seq(Row(Map(true -> true, true -> false))))

checkAnswer(dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)),
Seq(Row(Map(true -> true, true -> false))))
checkAnswer(dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)),
Seq(Row(Map(true -> true, true -> false))))
}

checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"),
Seq(Row(Map(50 -> true, 78 -> false))))
