[SPARK-25829][SQL] remove duplicated map keys with last wins policy #23124

Closed · wants to merge 7 commits
docs/sql-migration-guide-upgrade.md (2 changes: 2 additions & 0 deletions)
@@ -27,6 +27,8 @@ displayTitle: Spark SQL Upgrading Guide

- In Spark version 2.4 and earlier, float/double -0.0 is semantically equal to 0.0, but users can still distinguish them via `Dataset.show`, `Dataset.collect` etc. Since Spark 3.0, float/double -0.0 is replaced by 0.0 internally, and users can't distinguish them any more.

- In Spark version 2.4 and earlier, users can create a map with duplicated keys via built-in functions like `CreateMap`, `StringToMap`, etc. The behavior of map with duplicated keys is undefined, e.g. map look up respects the duplicated key appears first, `Dataset.collect` only keeps the duplicated key appears last, `MapKeys` returns duplicated keys, etc. Since Spark 3.0, these built-in functions will remove duplicated map keys with last wins policy. Users may still read map values with duplicated keys from data sources which do not enforce it (e.g. Parquet), the behavior will be udefined.
@gatorsmile (Member) commented on Dec 10, 2018:

A few typos.

In Spark version 2.4 and earlier, users can create a map with duplicate keys via built-in functions like CreateMap and StringToMap. The behavior of map with duplicate keys is undefined. For example, the map lookup respects the duplicate key that appears first, Dataset.collect only keeps the duplicate key that appears last, and MapKeys returns duplicate keys. Since Spark 3.0, these built-in functions will remove duplicate map keys using the last-one-wins policy. Users may still read map values with duplicate keys from the data sources that do not enforce it (e.g. Parquet), but the behavior will be undefined.
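As a minimal sketch of the behavior change described above (run from a `spark-shell` session; the outputs in the comments reflect what the old and new policies imply, not verbatim console output):

```scala
import org.apache.spark.sql.functions.{lit, map}

// CreateMap (the `map` function) called with a duplicated key.
val df = spark.range(1).select(map(lit(1), lit("a"), lit(1), lit("b")).as("m"))
df.show(false)
// Spark 2.4 and earlier: [1 -> a, 1 -> b]   -- duplicated keys kept, behavior undefined
// Since Spark 3.0:       [1 -> b]           -- last-one-wins deduplication
```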


## Upgrading From Spark SQL 2.3 to 2.4

- In Spark version 2.3 and earlier, the second parameter to the `array_contains` function is implicitly promoted to the element type of the first, array-type parameter. This type promotion can be lossy and may cause the `array_contains` function to return the wrong result. This problem has been addressed in 2.4 by employing a safer type promotion mechanism. This can cause some change in behavior and is illustrated in the table below.
@@ -17,7 +17,7 @@

package org.apache.spark.sql.avro

import java.math.{BigDecimal}
import java.math.BigDecimal
import java.nio.ByteBuffer

import scala.collection.JavaConverters._
@@ -218,6 +218,8 @@ class AvroDeserializer(rootAvroType: Schema, rootCatalystType: DataType) {
i += 1
}

// The Avro map will never have null or duplicated map keys, so it's safe to create an
// ArrayBasedMapData directly here.
updater.set(ordinal, new ArrayBasedMapData(keyArray, valueArray))

case (UNION, _) =>
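For code paths that cannot rule out duplicated keys, the last-one-wins policy can be pictured as a deduplication pass over the parallel key/value arrays before the map data is built. This is a hypothetical standalone sketch, not the builder the PR actually adds (that code is outside this excerpt):

```scala
import scala.collection.mutable

// Last-one-wins deduplication over parallel key/value sequences:
// the first occurrence fixes the key's position, the last occurrence fixes its value.
def dedupLastWins[K, V](keys: Seq[K], values: Seq[V]): (Seq[K], Seq[V]) = {
  val merged = mutable.LinkedHashMap.empty[K, V]
  keys.zip(values).foreach { case (k, v) => merged(k) = v }
  (merged.keys.toSeq, merged.values.toSeq)
}

// dedupLastWins(Seq(1, 2, 1), Seq("a", "b", "c")) == (Seq(1, 2), Seq("c", "b"))
```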
python/pyspark/sql/functions.py (10 changes: 5 additions & 5 deletions)
@@ -2656,11 +2656,11 @@ def map_concat(*cols):
>>> from pyspark.sql.functions import map_concat
>>> df = spark.sql("SELECT map(1, 'a', 2, 'b') as map1, map(3, 'c', 1, 'd') as map2")
>>> df.select(map_concat("map1", "map2").alias("map3")).show(truncate=False)
+--------------------------------+
|map3 |
+--------------------------------+
|[1 -> a, 2 -> b, 3 -> c, 1 -> d]|
+--------------------------------+
+------------------------+
|map3 |
+------------------------+
|[1 -> d, 2 -> b, 3 -> c]|
+------------------------+
"""
sc = SparkContext._active_spark_context
if len(cols) == 1 and isinstance(cols[0], (list, set)):
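The same change is visible from the Scala API; a minimal sketch (`spark-shell` session assumed, output abbreviated):

```scala
import org.apache.spark.sql.functions.{col, map_concat}

val df = spark.sql("SELECT map(1, 'a', 2, 'b') AS map1, map(3, 'c', 1, 'd') AS map2")
df.select(map_concat(col("map1"), col("map2")).as("map3")).show(false)
// Spark 2.4:       [1 -> a, 2 -> b, 3 -> c, 1 -> d]   -- duplicated key 1
// Since Spark 3.0: [1 -> d, 2 -> b, 3 -> c]           -- the later 1 -> 'd' wins
```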
@@ -28,6 +28,9 @@
* Currently we just use 2 UnsafeArrayData to represent UnsafeMapData, with extra 8 bytes at head
* to indicate the number of bytes of the unsafe key array.
* [unsafe key array numBytes] [unsafe key array] [unsafe value array]
*
* Note that the user is responsible for guaranteeing that the key array does not have duplicated
* elements; otherwise the behavior is undefined.
*/
// TODO: Use a more efficient format which doesn't depend on unsafe array.
public final class UnsafeMapData extends MapData {
@@ -431,12 +431,6 @@ object CatalystTypeConverters {
map,
(key: Any) => convertToCatalyst(key),
(value: Any) => convertToCatalyst(value))
case (keys: Array[_], values: Array[_]) =>
// case for mapdata with duplicate keys
new ArrayBasedMapData(
new GenericArrayData(keys.map(convertToCatalyst)),
new GenericArrayData(values.map(convertToCatalyst))
)
case other => other
}

@@ -125,22 +125,36 @@ object InternalRow {
* actually takes a `SpecializedGetters` input because it can be generalized to other classes
that implement `SpecializedGetters` (e.g., `ArrayData`) too.
*/
def getAccessor(dataType: DataType): (SpecializedGetters, Int) => Any = dataType match {
case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
case ByteType => (input, ordinal) => input.getByte(ordinal)
case ShortType => (input, ordinal) => input.getShort(ordinal)
case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal)
case FloatType => (input, ordinal) => input.getFloat(ordinal)
case DoubleType => (input, ordinal) => input.getDouble(ordinal)
case StringType => (input, ordinal) => input.getUTF8String(ordinal)
case BinaryType => (input, ordinal) => input.getBinary(ordinal)
case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal)
case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale)
case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size)
case _: ArrayType => (input, ordinal) => input.getArray(ordinal)
case _: MapType => (input, ordinal) => input.getMap(ordinal)
case u: UserDefinedType[_] => getAccessor(u.sqlType)
case _ => (input, ordinal) => input.get(ordinal, dataType)
def getAccessor(dt: DataType, nullable: Boolean = true): (SpecializedGetters, Int) => Any = {
Contributor (author) commented:
I can move it to a new PR if others think it's necessary. It's a little dangerous to ask the caller side to take care of null values.

val getValueNullSafe: (SpecializedGetters, Int) => Any = dt match {
case BooleanType => (input, ordinal) => input.getBoolean(ordinal)
case ByteType => (input, ordinal) => input.getByte(ordinal)
case ShortType => (input, ordinal) => input.getShort(ordinal)
case IntegerType | DateType => (input, ordinal) => input.getInt(ordinal)
case LongType | TimestampType => (input, ordinal) => input.getLong(ordinal)
case FloatType => (input, ordinal) => input.getFloat(ordinal)
case DoubleType => (input, ordinal) => input.getDouble(ordinal)
case StringType => (input, ordinal) => input.getUTF8String(ordinal)
case BinaryType => (input, ordinal) => input.getBinary(ordinal)
case CalendarIntervalType => (input, ordinal) => input.getInterval(ordinal)
case t: DecimalType => (input, ordinal) => input.getDecimal(ordinal, t.precision, t.scale)
case t: StructType => (input, ordinal) => input.getStruct(ordinal, t.size)
case _: ArrayType => (input, ordinal) => input.getArray(ordinal)
case _: MapType => (input, ordinal) => input.getMap(ordinal)
case u: UserDefinedType[_] => getAccessor(u.sqlType, nullable)
case _ => (input, ordinal) => input.get(ordinal, dt)
}

if (nullable) {
(getter, index) => {
if (getter.isNullAt(index)) {
null
} else {
getValueNullSafe(getter, index)
}
}
} else {
getValueNullSafe
}
}
}
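A minimal usage sketch of the new `getAccessor` signature (assumes the catalyst `InternalRow(values: Any*)` factory; names and values are illustrative):

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.types.IntegerType

// A row whose first field is null and whose second field is the int 42.
val row = InternalRow(null, 42)

val nullSafeGet = InternalRow.getAccessor(IntegerType)                   // nullable defaults to true
val fastGet     = InternalRow.getAccessor(IntegerType, nullable = false)

assert(nullSafeGet(row, 0) == null)  // the null check now lives inside the accessor
assert(fastGet(row, 1) == 42)        // caller guarantees non-null, so no per-row branch
```

This is what lets `BoundReference.eval` in the next diff collapse to a single `accessor(input, ordinal)` call.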
@@ -34,15 +34,11 @@ case class BoundReference(ordinal: Int, dataType: DataType, nullable: Boolean)

override def toString: String = s"input[$ordinal, ${dataType.simpleString}, $nullable]"

private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType)
private val accessor: (InternalRow, Int) => Any = InternalRow.getAccessor(dataType, nullable)

// Use special getter for primitive types (for UnsafeRow)
override def eval(input: InternalRow): Any = {
if (nullable && input.isNullAt(ordinal)) {
null
} else {
accessor(input, ordinal)
}
accessor(input, ordinal)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {