From 6dff6545f272e0d5117ac17fdc27b686573c5626 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 28 Nov 2018 15:44:08 +0800 Subject: [PATCH] address comments --- .../sql/catalyst/util/ArrayBasedMapBuilder.scala | 16 +++++++++------- .../util/ArrayBasedMapBuilderSuite.scala | 16 +++++++++++++++- 2 files changed, 24 insertions(+), 8 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala index fda2d748d7c8a..e7cd61655dc9a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilder.scala @@ -31,11 +31,13 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria assert(keyType != NullType, "map key cannot be null type.") private lazy val keyToIndex = keyType match { - case _: AtomicType | _: CalendarIntervalType => mutable.HashMap.empty[Any, Int] + // Binary type data is `byte[]`, which can't use `==` to check equality. + case _: AtomicType | _: CalendarIntervalType if !keyType.isInstanceOf[BinaryType] => + new java.util.HashMap[Any, Int]() case _ => // for complex types, use interpreted ordering to be able to compare unsafe data with safe // data, e.g. UnsafeRow vs GenericInternalRow. - mutable.TreeMap.empty[Any, Int](TypeUtils.getInterpretedOrdering(keyType)) + new java.util.TreeMap[Any, Int](TypeUtils.getInterpretedOrdering(keyType)) } // TODO: specialize it @@ -50,14 +52,14 @@ class ArrayBasedMapBuilder(keyType: DataType, valueType: DataType) extends Seria throw new RuntimeException("Cannot use null as map key.") } - val maybeExistingIdx = keyToIndex.get(key) - if (maybeExistingIdx.isDefined) { - // Overwrite the previous value, as the policy is last wins. - values(maybeExistingIdx.get) = value - } else { + val index = keyToIndex.getOrDefault(key, -1) + if (index == -1) { keyToIndex.put(key, values.length) keys.append(key) values.append(value) + } else { + // Overwrite the previous value, as the policy is last wins. + values(index) = value } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala index 4fb99be96a51e..8509bce177129 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/ArrayBasedMapBuilderSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.util import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{UnsafeArrayData, UnsafeRow} -import org.apache.spark.sql.types.{ArrayType, IntegerType, StructType} +import org.apache.spark.sql.types.{ArrayType, BinaryType, IntegerType, StructType} import org.apache.spark.unsafe.Platform class ArrayBasedMapBuilderSuite extends SparkFunSuite { @@ -52,6 +52,20 @@ class ArrayBasedMapBuilderSuite extends SparkFunSuite { assert(ArrayBasedMapData.toScalaMap(map) == Map(1 -> 2, 2 -> 2)) } + test("binary type key") { + val builder = new ArrayBasedMapBuilder(BinaryType, IntegerType) + builder.put(Array(1.toByte), 1) + builder.put(Array(2.toByte), 2) + builder.put(Array(1.toByte), 3) + val map = builder.build() + assert(map.numElements() == 2) + val entries = ArrayBasedMapData.toScalaMap(map).iterator.toSeq + assert(entries(0)._1.asInstanceOf[Array[Byte]].toSeq == Seq(1)) + assert(entries(0)._2 == 3) + assert(entries(1)._1.asInstanceOf[Array[Byte]].toSeq == Seq(2)) + assert(entries(1)._2 == 2) + } + test("struct type key") { val builder = new ArrayBasedMapBuilder(new StructType().add("i", "int"), IntegerType) builder.put(InternalRow(1), 1)