[SPARK-25829][SQL] remove duplicated map keys with last wins policy
## What changes were proposed in this pull request?

Currently, duplicated map keys are not handled consistently. For example, map lookup respects the duplicated key that appears first, `Dataset.collect` only keeps the duplicated key that appears last, `MapKeys` returns duplicated keys, etc.

This PR proposes to remove duplicated map keys with a last-wins policy, to follow Java/Scala and Presto. It applies only to built-in functions, as users can create maps with duplicated keys via private APIs anyway.

Updated functions: `CreateMap`, `MapFromArrays`, `MapFromEntries`, `StringToMap`, `MapConcat`, `TransformKeys`.
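As an illustration of the policy (a sketch of the semantics only, not the code added in this PR): the first occurrence of a key keeps its position, and a later occurrence overwrites its value, which is exactly what the updated `map_concat` doctest below shows.

```scala
import scala.collection.mutable

// Sketch: "last wins" deduplication of parallel key/value arrays.
// LinkedHashMap keeps the position of a key's first occurrence, while
// a later update overwrites the value associated with it.
def dedupLastWins[K, V](keys: Seq[K], values: Seq[V]): (Seq[K], Seq[V]) = {
  require(keys.length == values.length, "keys and values must be parallel")
  val builder = mutable.LinkedHashMap.empty[K, V]
  keys.zip(values).foreach { case (k, v) => builder(k) = v }
  (builder.keys.toSeq, builder.values.toSeq)
}

// map(1, 'a', 2, 'b', 1, 'd') keeps keys (1, 2) with values (d, b).
assert(dedupLastWins(Seq(1, 2, 1), Seq("a", "b", "d")) == (Seq(1, 2), Seq("d", "b")))
```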

For other places:
1. Data source v1 doesn't have this problem, as users need to provide a Java/Scala map, which can't have duplicated keys.
2. Data source v2 may have this problem. I've added a note to `ArrayBasedMapData` to ask the caller to take care of duplicated keys. In the future we should enforce it in the stable data APIs for data source v2.
3. UDFs don't have this problem, as users need to provide a Java/Scala map, same as data source v1.
4. File formats. I checked all of them and only Parquet does not enforce it. For backward compatibility reasons I changed nothing, but left a note saying that the behavior is undefined if users write maps with duplicated keys to Parquet files. Maybe we can add a config and fail by default if Parquet files contain maps with duplicated keys; this can be done in a follow-up.

## How was this patch tested?

Updated existing tests and added new tests.

Closes apache#23124 from cloud-fan/map.

Authored-by: Wenchen Fan <wenchen@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
cloud-fan committed Nov 28, 2018
1 parent 9fde3de commit fa0d4bf
Showing 22 changed files with 444 additions and 419 deletions.
2 changes: 2 additions & 0 deletions docs/sql-migration-guide-upgrade.md
@@ -27,6 +27,8 @@ displayTitle: Spark SQL Upgrading Guide

- In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but users can still distinguish them via `Dataset.show`, `Dataset.collect` etc. Since Spark 3.0, float/double -0.0 is replaced by 0.0 internally, and users can't distinguish them any more.

- In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of a map with duplicated keys is undefined, e.g. map lookup respects the duplicated key that appears first, `Dataset.collect` only keeps the duplicated key that appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with last wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet), in which case the behavior is undefined.
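As a sketch of the user-visible change described in this note (output is paraphrased from the note, not captured from a real session; assumes an active `SparkSession` named `spark`):

```scala
// Built-in map creation with a duplicated key.
val df = spark.sql("SELECT map(1, 'a', 1, 'b') AS m")
df.show()
// Spark 2.4 and earlier: `m` keeps both entries and the behavior of the
// duplicated key is undefined (lookup sees 'a', collect keeps 'b').
// Spark 3.0: the duplicated key is removed at creation time, last wins,
// so `m` is the single-entry map {1 -> b}.
```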

## Upgrading From Spark SQL 2.3 to 2.4

- In Spark version 2.3 and earlier, the second parameter of the `array_contains` function is implicitly promoted to the element type of the first, array-type parameter. This type promotion can be lossy and may cause `array_contains` to return a wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some changes in behavior, which are illustrated in the table below.
@@ -17,7 +17,7 @@

package org.apache.spark.sql.avro

-import java.math.{BigDecimal}
+import java.math.BigDecimal
import java.nio.ByteBuffer

import scala.collection.JavaConverters._
@@ -218,6 +218,8 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
i += 1
}

+          // The Avro map will never have null or duplicated map keys, so it's safe to create an
+          // ArrayBasedMapData directly here.
updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray))

case (UNION, _) =>
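For context on the comment added above, a minimal sketch of the pattern it guards, using the Spark-internal `ArrayBasedMapData` and `GenericArrayData` classes (the constructor performs no validation, which is why callers must rule out null and duplicated keys themselves):

```scala
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}

// ArrayBasedMapData simply wraps two parallel arrays and never checks them.
// Avro map keys are non-null and unique by construction, so the
// deserializer above can build the map data directly.
val keys = new GenericArrayData(Array[Any](1, 2))
val values = new GenericArrayData(Array[Any](10, 20))
val mapData = new ArrayBasedMapData(keys, values)
assert(mapData.numElements() == 2)
```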
10 changes: 5 additions & 5 deletions python/pyspark/sql/functions.py
@@ -2656,11 +2656,11 @@ def map_concat(*cols):
>>> from pyspark.sql.functions import map_concat
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2")
>>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False)
-    +--------------------------------+
-    |map3                            |
-    +--------------------------------+
-    |[1 -> a, 2 -> b, 3 -> c, 1 -> d]|
-    +--------------------------------+
+    +------------------------+
+    |map3                    |
+    +------------------------+
+    |[1 -> d, 2 -> b, 3 -> c]|
+    +------------------------+
"""
sc = SparkContext._active_spark_context
if len(cols) == 1 and isinstance(cols[0], (list, set)):
@@ -28,6 +28,9 @@
* Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 8 bytes at head
* to indicate the number of bytes of the unsafe key array.
* [unsafe key array numBytes] [unsafe key array] [unsafe value array]
+ *
+ * Note that the user is responsible for guaranteeing that the key array does not have
+ * duplicated elements; otherwise the behavior is undefined.
*/
// TODO: Use a more efficient format which doesn't depend on unsafe array.
public final class UnsafeMapData extends MapData {
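The layout documented in that comment can be decoded with plain offset arithmetic. A sketch of the idea follows; the helper name and signature are illustrative, not the actual `UnsafeMapData` code:

```scala
import org.apache.spark.unsafe.Platform

// Illustrative decoding of the documented layout:
// [key array numBytes: 8 bytes] [unsafe key array] [unsafe value array]
// Returns (keyArrayOffset, valueArrayOffset, valueArrayBytes).
def splitMapBuffer(base: AnyRef, offset: Long, totalBytes: Long): (Long, Long, Long) = {
  val keyArrayBytes = Platform.getLong(base, offset) // leading 8-byte size field
  val keyArrayOffset = offset + 8                    // key array starts right after it
  val valueArrayOffset = keyArrayOffset + keyArrayBytes
  val valueArrayBytes = totalBytes - 8 - keyArrayBytes
  (keyArrayOffset, valueArrayOffset, valueArrayBytes)
}
```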
@@ -431,12 +431,6 @@ object CatalystTypeConverters {
map,
(key: Any) => convertToCatalyst(key),
(value: Any) => convertToCatalyst(value))
-      case (keys: Array[_], values: Array[_]) =>
-        // case for mapdata with duplicate keys
-        new ArrayBasedMapData(
-          new GenericArrayData(keys.map(convertToCatalyst)),
-          new GenericArrayData(values.map(convertToCatalyst))
-        )
case other => other
}

@@ -125,22 +125,36 @@ object InternalRow {
* actually takes a `SpecializedGetters` input because it can be generalized to other classes
* that implements `SpecializedGetters` (e.g., `ArrayData`) too.
*/
-  def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match {
-    case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
-    case ByteType => (input, ordinal) => input.getByte(ordinal)
-    case ShortType => (input, ordinal) => input.getShort(ordinal)
-    case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
-    case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal)
-    case FloatType => (input, ordinal) => input.getFloat(ordinal)
-    case DoubleType => (input, ordinal) => input.getDouble(ordinal)
-    case StringType => (input, ordinal) => input.getUTF8String(ordinal)
-    case BinaryType => (input, ordinal) => input.getBinary(ordinal)
-    case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal)
-    case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale)
-    case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size)
-    case _: ArrayType => (input, ordinal) => input.getArray(ordinal)
-    case _: MapType => (input, ordinal) => input.getMap(ordinal)
-    case u: UserDefinedType[_] => getAccessor(u.sqlType)
-    case _ => (input, ordinal) => input.get(ordinal, dataType)
+  def getAccessor(dt: DataType, nullable: Boolean = true): (SpecializedGetters, Int) => Any = {
+    val getValueNullSafe: (SpecializedGetters, Int) => Any = dt match {
+      case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
+      case ByteType => (input, ordinal) => input.getByte(ordinal)
+      case ShortType => (input, ordinal) => input.getShort(ordinal)
+      case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
+      case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal)
+      case FloatType => (input, ordinal) => input.getFloat(ordinal)
+      case DoubleType => (input, ordinal) => input.getDouble(ordinal)
+      case StringType => (input, ordinal) => input.getUTF8String(ordinal)
+      case BinaryType => (input, ordinal) => input.getBinary(ordinal)
+      case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal)
+      case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale)
+      case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size)
+      case _: ArrayType => (input, ordinal) => input.getArray(ordinal)
+      case _: MapType => (input, ordinal) => input.getMap(ordinal)
+      case u: UserDefinedType[_] => getAccessor(u.sqlType, nullable)
+      case _ => (input, ordinal) => input.get(ordinal, dt)
+    }
+
+    if (nullable) {
+      (getter, index) => {
+        if (getter.isNullAt(index)) {
+          null
+        } else {
+          getValueNullSafe(getter, index)
+        }
+      }
+    } else {
+      getValueNullSafe
+    }
}
}
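A short usage sketch of the rewritten accessor (assuming the Spark-internal `InternalRow` API shown above): with `nullable = true` the returned function performs the `isNullAt` check itself, which is what lets `BoundReference.eval` in the next file collapse to a single call.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.IntegerType

// The null check is baked into the accessor once, instead of being
// re-implemented by every caller.
val accessor = InternalRow.getAccessor(IntegerType, nullable = true)
val row = InternalRow(1, null)    // two int columns, the second is null
assert(accessor(row, 0) == 1)
assert(accessor(row, 1) == null)  // handled inside the accessor
```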
@@ -34,15 +34,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)

override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"

-  private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType)
+  private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType, nullable)

// Use special getter for primitive types (for UnsafeRow)
override def eval(input: InternalRow): Any = {
-    if (nullable && input.isNullAt(ordinal)) {
-      null
-    } else {
-      accessor(input, ordinal)
-    }
+    accessor(input, ordinal)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {