Add config for the default behavior
xuanyuanking committed Feb 6, 2020
1 parent d861357 commit 09cfb67
Showing 9 changed files with 123 additions and 77 deletions.
2 changes: 1 addition & 1 deletion docs/sql-migration-guide.md
@@ -49,7 +49,7 @@ license: |

- In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but -0.0 and 0.0 are considered different values when used in aggregate grouping keys, window partition keys and join keys. Since Spark 3.0, this bug is fixed. For example, `Seq(-0.0, 0.0).toDF("d").groupBy("d").count()` returns `[(0.0, 2)]` in Spark 3.0, and `[(0.0, 1), (-0.0, 1)]` in Spark 2.4 and earlier.

- In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of a map with duplicated keys is undefined, e.g. map lookup respects the duplicated key that appears first, `Dataset.collect` only keeps the duplicated key that appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with the last-wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet); the behavior will be undefined.
- In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of a map with duplicated keys is undefined, e.g. map lookup respects the duplicated key that appears first, `Dataset.collect` only keeps the duplicated key that appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, a new config `spark.sql.deduplicateMapKey.lastWinsPolicy.enabled` was added, with the default value `false`; Spark throws a RuntimeException when duplicated keys are found. If it is set to `true`, these built-in functions will remove duplicated map keys with the last-wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet); the behavior will be undefined.

- In Spark version 2.4 and earlier, a partition column value is converted to null if it can't be cast to the corresponding user-provided schema. Since 3.0, the partition column value is validated against the user-provided schema. An exception is thrown if the validation fails. You can disable such validation by setting `spark.sql.sources.validatePartitionColumns` to `false`.

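A minimal spark-shell sketch of the behavior described in the note above (the config name is taken from this commit; exact output and error text are approximate):

```scala
// Default (spark.sql.deduplicateMapKey.lastWinsPolicy.enabled = false):
// building a map with a duplicated key now fails at runtime.
spark.sql("SELECT map(1, 'a', 1, 'b') AS m").collect()
// java.lang.RuntimeException: Duplicate map key 1 was found, please set
// spark.sql.deduplicateMapKey.lastWinsPolicy.enabled to true to remove it
// with last wins policy.

// Opting in restores the Spark 2.4-era deduplication, keeping the last value.
spark.conf.set("spark.sql.deduplicateMapKey.lastWinsPolicy.enabled", "true")
spark.sql("SELECT map(1, 'a', 1, 'b') AS m").collect()
// Array([Map(1 -> b)])
```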
2 changes: 2 additions & 0 deletions python/pyspark/sql/functions.py
@@ -2766,13 +2766,15 @@ def map_concat(*cols):
:param cols: list of column names (string) or list of :class:`Column` expressions
>>> from pyspark.sql.functions import map_concat
>>> spark.conf.set("spark.sql.deduplicateMapKey.lastWinsPolicy.enabled", "true")
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2")
>>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False)
+------------------------+
|map3 |
+------------------------+
|[1 -> d, 2 -> b, 3 -> c]|
+------------------------+
>>> spark.conf.unset("spark.sql.deduplicateMapKey.lastWinsPolicy.enabled")
"""
sc = SparkContext._active_spark_context
if len(cols) == 1 and isinstance(cols[0], (list, set)):
@@ -20,6 +20,8 @@ package org.apache.spark.sql.catalyst.util
import scala.collection.mutable

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.internal.SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.array.ByteArrayMethods

@@ -63,6 +65,11 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
keys.append(key)
values.append(value)
} else {
if (!SQLConf.get.getConf(DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY)) {
throw new RuntimeException(s"Duplicate map key $key was found, please set " +
s"${DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key} to true to remove it with " +
"last wins policy.")
}
// Overwrite the previous value, as the policy is last wins.
values(index) = value
}
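A sketch of the builder's semantics after this hunk (the `put`/`build` calls mirror the test suite further down; whether `put` throws depends on the session conf read via `SQLConf.get`):

```scala
import org.apache.spark.sql.catalyst.util.ArrayBasedMapBuilder
import org.apache.spark.sql.types.IntegerType

val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
builder.put(1, 1)
builder.put(2, 2)
// Duplicate key: with the conf at its default (false) this put throws
// "Duplicate map key 1 was found, ...". With the conf set to true it
// silently overwrites the previous value (last wins).
builder.put(1, 3)
val map = builder.build() // two entries under last wins: 1 -> 3, 2 -> 2
```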
@@ -2167,6 +2167,14 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

val DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY =
buildConf("spark.sql.deduplicateMapKey.lastWinsPolicy.enabled")
.doc("When true, use last wins policy to remove duplicated map keys in built-in functions, " +
"this config takes effect in below build-in functions: CreateMap, MapFromArrays, " +
"MapFromEntries, StringToMap, MapConcat and TransformKeys.")
.booleanConf
.createWithDefault(false)

/**
* Holds information about keys that have been deprecated.
*
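The entry is an ordinary `SQLConf` boolean flag, so it can be toggled per session; a hedged sketch of the usual access patterns:

```scala
// From SQL or the session conf:
spark.sql("SET spark.sql.deduplicateMapKey.lastWinsPolicy.enabled=true")
spark.conf.set("spark.sql.deduplicateMapKey.lastWinsPolicy.enabled", "false")

// Internally (as ArrayBasedMapBuilder does above):
val lastWins: Boolean =
  SQLConf.get.getConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY)
```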
@@ -139,8 +139,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
MapType(IntegerType, IntegerType, valueContainsNull = true))
val mNull = Literal.create(null, MapType(StringType, StringType))

// overlapping maps should remove duplicated map keys w.r.t. last win policy.
checkEvaluation(MapConcat(Seq(m0, m1)), create_map("a" -> "4", "b" -> "2", "c" -> "3"))
withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
// overlapping maps should remove duplicated map keys w.r.t. last win policy.
checkEvaluation(MapConcat(Seq(m0, m1)), create_map("a" -> "4", "b" -> "2", "c" -> "3"))
}

// maps with no overlap
checkEvaluation(MapConcat(Seq(m0, m2)),
@@ -272,8 +274,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MapFromEntries(ai1), create_map(1 -> null, 2 -> 20, 3 -> null))
checkEvaluation(MapFromEntries(ai2), Map.empty)
checkEvaluation(MapFromEntries(ai3), null)
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(MapFromEntries(ai4), create_map(1 -> 20))
withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(MapFromEntries(ai4), create_map(1 -> 20))
}
// Map key can't be null
checkExceptionInExpression[RuntimeException](
MapFromEntries(ai5),
@@ -294,8 +298,10 @@ class CollectionExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(MapFromEntries(as1), create_map("a" -> null, "b" -> "bb", "c" -> null))
checkEvaluation(MapFromEntries(as2), Map.empty)
checkEvaluation(MapFromEntries(as3), null)
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(MapFromEntries(as4), create_map("a" -> "bb"))
withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(MapFromEntries(as4), create_map("a" -> "bb"))
}
// Map key can't be null
checkExceptionInExpression[RuntimeException](
MapFromEntries(as5),
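These assertions are wrapped in `withSQLConf`, the test helper that sets the given entries, runs the body, and then restores the previous values; a sketch of the pattern:

```scala
withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
  // Inside the block the last-wins policy is on, so duplicate-key inputs
  // evaluate to a deduplicated map instead of throwing.
  checkEvaluation(MapFromEntries(ai4), create_map(1 -> 20))
} // The previous value (default: false) is restored here, even on failure.
```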
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, UnresolvedExtractValue}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String

@@ -216,10 +217,12 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
CreateMap(interlace(strWithNull, intSeq.map(Literal(_)))),
"Cannot use null as map key")

// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(
CreateMap(Seq(Literal(1), Literal(2), Literal(1), Literal(3))),
create_map(1 -> 3))
withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(
CreateMap(Seq(Literal(1), Literal(2), Literal(1), Literal(3))),
create_map(1 -> 3))
}

// ArrayType map key and value
val map = CreateMap(Seq(
@@ -281,12 +284,14 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
MapFromArrays(intWithNullArray, strArray),
"Cannot use null as map key")

// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(
MapFromArrays(
Literal.create(Seq(1, 1), ArrayType(IntegerType)),
Literal.create(Seq(2, 3), ArrayType(IntegerType))),
create_map(1 -> 3))
withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(
MapFromArrays(
Literal.create(Seq(1, 1), ArrayType(IntegerType)),
Literal.create(Seq(2, 3), ArrayType(IntegerType))),
create_map(1 -> 3))
}

// map key can't be map
val arrayOfMap = Seq(create_map(1 -> "a", 2 -> "b"))
@@ -399,10 +404,12 @@ class ComplexTypeSuite extends SparkFunSuite with ExpressionEvalHelper {
val m5 = Map("a" -> null)
checkEvaluation(new StringToMap(s5), m5)

// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(
new StringToMap(Literal("a:1,b:2,a:3")),
create_map("a" -> "3", "b" -> "2"))
withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(
new StringToMap(Literal("a:1,b:2,a:3")),
create_map("a" -> "3", "b" -> "2"))
}

// arguments checking
assert(new StringToMap(Literal("a:1,b:2,c:3")).checkInputDataTypes().isSuccess)
@@ -465,8 +465,10 @@ class HigherOrderFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper
checkEvaluation(
transformKeys(transformKeys(ai0, plusOne), plusValue),
create_map(3 -> 1, 5 -> 2, 7 -> 3, 9 -> 4))
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(transformKeys(ai0, modKey), create_map(1 -> 4, 2 -> 2, 0 -> 3))
withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
// Duplicated map keys will be removed w.r.t. the last wins policy.
checkEvaluation(transformKeys(ai0, modKey), create_map(1 -> 4, 2 -> 2, 0 -> 3))
}
checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
checkEvaluation(transformKeys(ai1, plusOne), Map.empty[Int, Int])
checkEvaluation(
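TransformKeys needs the opt-in because a key-mapping lambda can collapse distinct keys into one, as `modKey` does above; a hedged SQL-level sketch (output shape approximate):

```scala
spark.conf.set("spark.sql.deduplicateMapKey.lastWinsPolicy.enabled", "true")
spark.sql("SELECT transform_keys(map(1, 1, 4, 4), (k, v) -> k % 3) AS m").collect()
// Keys 1 and 4 both map to 1; under last wins the result is Map(1 -> 4).
```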
@@ -20,10 +20,12 @@ package org.apache.spark.sql.catalyst.util
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType, StructType}
import org.apache.spark.unsafe.Platform

class ArrayBasedMapBuilderSuite extends SparkFunSuite {
class ArrayBasedMapBuilderSuite extends SparkFunSuite with SQLHelper {

test("basic") {
val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
@@ -43,63 +45,71 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite {
}

test("remove duplicated keys with last wins policy") {
val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
builder.put(1, 1)
builder.put(2, 2)
builder.put(1, 2)
val map = builder.build()
assert(map.numElements() == 2)
assert(ArrayBasedMapData.toScalaMap(map) == Map(1 -> 2, 2 -> 2))
withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
val builder = new ArrayBasedMapBuilder(IntegerType, IntegerType)
builder.put(1, 1)
builder.put(2, 2)
builder.put(1, 2)
val map = builder.build()
assert(map.numElements() == 2)
assert(ArrayBasedMapData.toScalaMap(map) == Map(1 -> 2, 2 -> 2))
}
}

test("binary type key") {
val builder = new ArrayBasedMapBuilder(BinaryType, IntegerType)
builder.put(Array(1.toByte), 1)
builder.put(Array(2.toByte), 2)
builder.put(Array(1.toByte), 3)
val map = builder.build()
assert(map.numElements() == 2)
val entries = ArrayBasedMapData.toScalaMap(map).iterator.toSeq
assert(entries(0)._1.asInstanceOf[Array[Byte]].toSeq == Seq(1))
assert(entries(0)._2 == 3)
assert(entries(1)._1.asInstanceOf[Array[Byte]].toSeq == Seq(2))
assert(entries(1)._2 == 2)
withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
val builder = new ArrayBasedMapBuilder(BinaryType, IntegerType)
builder.put(Array(1.toByte), 1)
builder.put(Array(2.toByte), 2)
builder.put(Array(1.toByte), 3)
val map = builder.build()
assert(map.numElements() == 2)
val entries = ArrayBasedMapData.toScalaMap(map).iterator.toSeq
assert(entries(0)._1.asInstanceOf[Array[Byte]].toSeq == Seq(1))
assert(entries(0)._2 == 3)
assert(entries(1)._1.asInstanceOf[Array[Byte]].toSeq == Seq(2))
assert(entries(1)._2 == 2)
}
}

test("struct type key") {
val builder = new ArrayBasedMapBuilder(new StructType().add("i", "int"), IntegerType)
builder.put(InternalRow(1), 1)
builder.put(InternalRow(2), 2)
val unsafeRow = {
val row = new UnsafeRow(1)
val bytes = new Array[Byte](16)
row.pointTo(bytes, 16)
row.setInt(0, 1)
row
withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
val builder = new ArrayBasedMapBuilder(new StructType().add("i", "int"), IntegerType)
builder.put(InternalRow(1), 1)
builder.put(InternalRow(2), 2)
val unsafeRow = {
val row = new UnsafeRow(1)
val bytes = new Array[Byte](16)
row.pointTo(bytes, 16)
row.setInt(0, 1)
row
}
builder.put(unsafeRow, 3)
val map = builder.build()
assert(map.numElements() == 2)
assert(ArrayBasedMapData.toScalaMap(map) == Map(InternalRow(1) -> 3, InternalRow(2) -> 2))
}
builder.put(unsafeRow, 3)
val map = builder.build()
assert(map.numElements() == 2)
assert(ArrayBasedMapData.toScalaMap(map) == Map(InternalRow(1) -> 3, InternalRow(2) -> 2))
}

test("array type key") {
val builder = new ArrayBasedMapBuilder(ArrayType(IntegerType), IntegerType)
builder.put(new GenericArrayData(Seq(1, 1)), 1)
builder.put(new GenericArrayData(Seq(2, 2)), 2)
val unsafeArray = {
val array = new UnsafeArrayData()
val bytes = new Array[Byte](24)
Platform.putLong(bytes, Platform.BYTE_ARRAY_OFFSET, 2)
array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET, 24)
array.setInt(0, 1)
array.setInt(1, 1)
array
withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
val builder = new ArrayBasedMapBuilder(ArrayType(IntegerType), IntegerType)
builder.put(new GenericArrayData(Seq(1, 1)), 1)
builder.put(new GenericArrayData(Seq(2, 2)), 2)
val unsafeArray = {
val array = new UnsafeArrayData()
val bytes = new Array[Byte](24)
Platform.putLong(bytes, Platform.BYTE_ARRAY_OFFSET, 2)
array.pointTo(bytes, Platform.BYTE_ARRAY_OFFSET, 24)
array.setInt(0, 1)
array.setInt(1, 1)
array
}
builder.put(unsafeArray, 3)
val map = builder.build()
assert(map.numElements() == 2)
assert(ArrayBasedMapData.toScalaMap(map) ==
Map(new GenericArrayData(Seq(1, 1)) -> 3, new GenericArrayData(Seq(2, 2)) -> 2))
}
builder.put(unsafeArray, 3)
val map = builder.build()
assert(map.numElements() == 2)
assert(ArrayBasedMapData.toScalaMap(map) ==
Map(new GenericArrayData(Seq(1, 1)) -> 3, new GenericArrayData(Seq(2, 2)) -> 2))
}
}
@@ -651,8 +651,10 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
Row(null)
)

checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a)
checkAnswer(df1.select(map_concat($"map1", $"map2")), expected1a)
withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
checkAnswer(df1.selectExpr("map_concat(map1, map2)"), expected1a)
checkAnswer(df1.select(map_concat($"map1", $"map2")), expected1a)
}

val expected1b = Seq(
Row(Map(1 -> 100, 2 -> 200)),
@@ -3068,11 +3070,13 @@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession {
checkAnswer(dfExample2.select(transform_keys(col("j"), (k, v) => k + v)),
Seq(Row(Map(2.0 -> 1.0, 3.4 -> 1.4, 4.7 -> 1.7))))

checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"),
Seq(Row(Map(true -> true, true -> false))))
withSQLConf(SQLConf.DEDUPLICATE_MAP_KEY_WITH_LAST_WINS_POLICY.key -> "true") {
checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> k % 2 = 0 OR v)"),
Seq(Row(Map(true -> true, true -> false))))

checkAnswer(dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)),
Seq(Row(Map(true -> true, true -> false))))
checkAnswer(dfExample3.select(transform_keys(col("x"), (k, v) => k % 2 === 0 || v)),
Seq(Row(Map(true -> true, true -> false))))
}

checkAnswer(dfExample3.selectExpr("transform_keys(x, (k, v) -> if(v, 2 * k, 3 * k))"),
Seq(Row(Map(50 -> true, 78 -> false))))
