Skip to content

Commit

Permalink
address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
cloud-fan committed Nov 28, 2018
1 parent b7073b2 commit 6dff654
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
assert(keyType != NullType, "map key cannot be null type.")

private lazy val keyToIndex = keyType match {
case _: AtomicType | _: CalendarIntervalType => mutable.HashMap.empty[Any, Int]
// Binary type data is `byte[]`, which can't use `==` to check equality.
case _: AtomicType | _: CalendarIntervalType if !keyType.isInstanceOf[BinaryType] =>
new java.util.HashMap[Any, Int]()
case _ =>
// for complex types, use interpreted ordering to be able to compare unsafe data with safe
// data, e.g. UnsafeRow vs GenericInternalRow.
mutable.TreeMap.empty[Any, Int](TypeUtils.getInterpretedOrdering(keyType))
new java.util.TreeMap[Any, Int](TypeUtils.getInterpretedOrdering(keyType))
}

// TODO: specialize it
Expand All @@ -50,14 +52,14 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria
throw new RuntimeException("Cannot use null as map key.")
}

val maybeExistingIdx = keyToIndex.get(key)
if (maybeExistingIdx.isDefined) {
// Overwrite the previous value, as the policy is last wins.
values(maybeExistingIdx.get) = value
} else {
val index = keyToIndex.getOrDefault(key, -1)
if (index == -1) {
keyToIndex.put(key, values.length)
keys.append(key)
values.append(value)
} else {
// Overwrite the previous value, as the policy is last wins.
values(index) = value
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow}
import org.apache.spark.sql.types.{ArrayType, IntegerType, StructType}
import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType, StructType}
import org.apache.spark.unsafe.Platform

class ArrayBasedMapBuilderSuite extends SparkFunSuite {
Expand Down Expand Up @@ -52,6 +52,20 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite {
assert(ArrayBasedMapData.toScalaMap(map) == Map(1 -> 2, 2 -> 2))
}

test("binary type key") {
  // Binary keys are `Array[Byte]`, which only has reference equality on the JVM. The builder
  // must still treat two byte arrays with the same contents as the same map key.
  val builder = new ArrayBasedMapBuilder(BinaryType, IntegerType)
  builder.put(Array(1.toByte), 1)
  builder.put(Array(2.toByte), 2)
  // Same contents as the first key: expected to overwrite its value (last wins).
  builder.put(Array(1.toByte), 3)
  val map = builder.build()
  assert(map.numElements() == 2)
  // Convert each key to `List[Byte]` so keys compare (and hash) by content, then compare whole
  // maps so the assertion does not depend on the iteration order of the returned Scala map.
  val entries = ArrayBasedMapData.toScalaMap(map).map {
    case (k, v) => k.asInstanceOf[Array[Byte]].toList -> v
  }
  assert(entries == Map(List(1.toByte) -> 3, List(2.toByte) -> 2))
}

test("struct type key") {
val builder = new ArrayBasedMapBuilder(new StructType().add("i", "int"), IntegerType)
builder.put(InternalRow(1), 1)
Expand Down

0 comments on commit 6dff654

Please sign in to comment.