Skip to content

Commit

Permalink
[SPARK-16135][SQL] Remove hashCode and euqals in ArrayBasedMapData
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?
This pr is to remove `hashCode` and `equals` in `ArrayBasedMapData` because the type cannot be used as join keys, grouping keys, or in equality tests.

## How was this patch tested?
Add a new test suite `MapDataSuite` for comparison tests.

Author: Takeshi YAMAMURO <linguin.m.s@gmail.com>

Closes #13847 from maropu/UnsafeMapTest.

(cherry picked from commit 3e4e868)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
maropu authored and cloud-fan committed Jun 27, 2016
1 parent 664426e commit 22fe336
Show file tree
Hide file tree
Showing 6 changed files with 76 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,10 @@ public UnsafeMapData getMap(int ordinal) {
return map;
}

// This `hashCode` computation could consume much processor time for large data.
// If the computation becomes a bottleneck, we can use a light-weight logic; the first fixed bytes
// are used to compute `hashCode` (See `Vector.hashCode`).
// The same issue exists in `UnsafeRow.hashCode`.
@Override
public int hashCode() {
return Murmur3_x86_32.hashUnsafeBytes(baseObject, baseOffset, sizeInBytes, 42);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,6 @@ class ArrayBasedMapData(val keyArray: ArrayData, val valueArray: ArrayData) exte

override def copy(): MapData = new ArrayBasedMapData(keyArray.copy(), valueArray.copy())

override def equals(o: Any): Boolean = {
if (!o.isInstanceOf[ArrayBasedMapData]) {
return false
}

val other = o.asInstanceOf[ArrayBasedMapData]
if (other eq null) {
return false
}

this.keyArray == other.keyArray && this.valueArray == other.valueArray
}

override def hashCode: Int = {
keyArray.hashCode() * 37 + valueArray.hashCode()
}

override def toString: String = {
s"keys: $keyArray, values: $valueArray"
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ package org.apache.spark.sql.catalyst.util

import org.apache.spark.sql.types.DataType

/**
* This is an internal data representation for map type in Spark SQL. This should not implement
* `equals` and `hashCode` because the type cannot be used as join keys, grouping keys, or
* in equality tests. See SPARK-9415 and PR#13847 for the discussions.
*/
abstract class MapData extends Serializable {

def numElements(): Int
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,10 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper {
case (expr, i) => Seq(Literal(i), expr)
}))
val plan = GenerateMutableProjection.generate(expressions)
val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType))
val expected = Seq(new ArrayBasedMapData(
new GenericArrayData(0 until length),
new GenericArrayData(Seq.fill(length)(true))))
val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)).map {
case m: ArrayBasedMapData => ArrayBasedMapData.toScalaMap(m)
}
val expected = (0 until length).map((_, true)).toMap :: Nil

if (!checkResult(actual, expected)) {
fail(s"Incorrect Evaluation: expressions: $expressions, actual: $actual, expected: $expected")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.optimizer.SimpleTestOptimizer
import org.apache.spark.sql.catalyst.plans.logical.{OneRowRelation, Project}
import org.apache.spark.sql.catalyst.util.MapData
import org.apache.spark.sql.types.DataType
import org.apache.spark.util.Utils

Expand All @@ -52,15 +53,18 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks {

/**
* Check the equality between result of expression and expected value, it will handle
* Array[Byte] and Spread[Double].
* Array[Byte], Spread[Double], and MapData.
*/
protected def checkResult(result: Any, expected: Any): Boolean = {
(result, expected) match {
case (result: Array[Byte], expected: Array[Byte]) =>
java.util.Arrays.equals(result, expected)
case (result: Double, expected: Spread[Double @unchecked]) =>
expected.asInstanceOf[Spread[Double]].isWithin(result)
case _ => result == expected
case (result: MapData, expected: MapData) =>
result.keyArray() == expected.keyArray() && result.valueArray() == expected.valueArray()
case _ =>
result == expected
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.expressions

import scala.collection._

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
import org.apache.spark.sql.types.{DataType, IntegerType, MapType, StringType}
import org.apache.spark.unsafe.types.UTF8String

class MapDataSuite extends SparkFunSuite {

test("inequality tests") {
def u(str: String): UTF8String = UTF8String.fromString(str)

// test data
val testMap1 = Map(u("key1") -> 1)
val testMap2 = Map(u("key1") -> 1, u("key2") -> 2)
val testMap3 = Map(u("key1") -> 1)
val testMap4 = Map(u("key1") -> 1, u("key2") -> 2)

// ArrayBasedMapData
val testArrayMap1 = ArrayBasedMapData(testMap1.toMap)
val testArrayMap2 = ArrayBasedMapData(testMap2.toMap)
val testArrayMap3 = ArrayBasedMapData(testMap3.toMap)
val testArrayMap4 = ArrayBasedMapData(testMap4.toMap)
assert(testArrayMap1 !== testArrayMap3)
assert(testArrayMap2 !== testArrayMap4)

// UnsafeMapData
val unsafeConverter = UnsafeProjection.create(Array[DataType](MapType(StringType, IntegerType)))
val row = new GenericMutableRow(1)
def toUnsafeMap(map: ArrayBasedMapData): UnsafeMapData = {
row.update(0, map)
val unsafeRow = unsafeConverter.apply(row)
unsafeRow.getMap(0).copy
}
assert(toUnsafeMap(testArrayMap1) !== toUnsafeMap(testArrayMap3))
assert(toUnsafeMap(testArrayMap2) !== toUnsafeMap(testArrayMap4))
}
}

0 comments on commit 22fe336

Please sign in to comment.