Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
maropu committed May 16, 2021
1 parent 8308abe commit 38e42c4
Show file tree
Hide file tree
Showing 27 changed files with 1,474 additions and 688 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

import org.apache.spark.sql.catalyst.util.MapData;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.hash.Murmur3_x86_32;

import static org.apache.spark.unsafe.Platform.BYTE_ARRAY_OFFSET;

Expand Down Expand Up @@ -112,6 +114,22 @@ public UnsafeArrayData valueArray() {
return values;
}

@Override
public int hashCode() {
return Murmur3_x86_32.hashUnsafeBytes(baseObject, baseOffset, sizeInBytes, 42);
}

@Override
public boolean equals(Object other) {
if (other instanceof UnsafeMapData) {
UnsafeMapData o = (UnsafeMapData) other;
return (sizeInBytes == o.sizeInBytes) &&
ByteArrayMethods.arrayEquals(baseObject, baseOffset, o.baseObject, o.baseOffset,
sizeInBytes);
}
return false;
}

public void writeToMemory(Object target, long targetOffset) {
Platform.copyMemory(baseObject, baseOffset, target, targetOffset, sizeInBytes);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,14 +61,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
dt.existsRecursively(_.isInstanceOf[MapType])
}

protected def mapColumnInSetOperation(plan: LogicalPlan): Option[Attribute] = plan match {
case _: Intersect | _: Except | _: Distinct =>
plan.output.find(a => hasMapType(a.dataType))
case d: Deduplicate =>
d.keys.find(a => hasMapType(a.dataType))
case _ => None
}

private def checkLimitLikeClause(name: String, limitExpr: Expression): Unit = {
limitExpr match {
case e if !e.foldable => failAnalysis(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -689,118 +689,32 @@ class CodegenContext extends Logging {
}
"""
s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
case _ @ MapType(keyType, valueType, valueContainsNull) =>
val compareMapFunc = freshName("compareMap")
val initIndexArrayFunc = freshName("initIndexArray")
val keyIndexComparator = freshName("keyIndexComparator")
val compareKeyFunc = freshName("compareKey")
val compareValueFunc = freshName("compareValue")
val nullSafeCompare =
s"""
|${javaType(valueType)} left = ${getValue("leftArray", valueType, "leftIndex")};
|${javaType(valueType)} right = ${getValue("rightArray", valueType, "rightIndex")};
|return ${genComp(valueType, "left", "right")};
|""".stripMargin
val compareElement = if (valueContainsNull) {
s"""
|boolean isNullA = leftArray.isNullAt(leftIndex);
|boolean isNullB = rightArray.isNullAt(rightIndex);
|if (isNullA && isNullB) {
| return 0;
|} else if (isNullA) {
| return -1;
|} else if (isNullB) {
| return 1;
|} else {
| $nullSafeCompare
|}
|""".stripMargin
} else {
nullSafeCompare
}

addNewFunction(initIndexArrayFunc,
s"""
|private Integer[] $initIndexArrayFunc(int n) {
| Integer[] arr = new Integer[n];
| for (int i = 0; i < n; i++) {
| arr[i] = i;
| }
| return arr;
|}""".stripMargin)


addNewFunction(keyIndexComparator,
s"""
|private class $keyIndexComparator implements java.util.Comparator<Integer> {
| private ArrayData array;
| public $keyIndexComparator(ArrayData array) {
| this.array = array;
| }
|
| @Override
| public int compare(Object a, Object b) {
| int indexA = ((Integer)a).intValue();
| int indexB = ((Integer)b).intValue();
| ${javaType(keyType)} keyA = ${getValue("array", keyType, "indexA")};
| ${javaType(keyType)} keyB = ${getValue("array", keyType, "indexB")};
| return ${genComp(keyType, "keyA", "keyB")};
| }
|}""".stripMargin)

addNewFunction(compareKeyFunc,
case _ @ MapType(keyType, valueType, _) =>
val keyArrayType = ArrayType(keyType)
val valueArrayType = ArrayType(valueType)
val compareFunc = freshName("compareMap")
val funcCode: String =
s"""
|private int $compareKeyFunc(ArrayData leftArray, int leftIndex, ArrayData rightArray,
| int rightIndex) {
| ${javaType(keyType)} left = ${getValue("leftArray", keyType, "leftIndex")};
| ${javaType(keyType)} right = ${getValue("rightArray", keyType, "rightIndex")};
| return ${genComp(keyType, "left", "right")};
|}
|""".stripMargin)
public int $compareFunc(MapData a, MapData b) {
ArrayData aKeys = a.keyArray();
ArrayData bKeys = b.keyArray();
int keyComp = ${genComp(keyArrayType, "aKeys", "bKeys")};
if (keyComp != 0) {
return keyComp;
}

addNewFunction(compareValueFunc,
s"""
|private int $compareValueFunc(ArrayData leftArray, int leftIndex, ArrayData rightArray,
| int rightIndex) {
| $compareElement
|}
|""".stripMargin)
ArrayData aValues = a.valueArray();
ArrayData bValues = b.valueArray();
int valueComp = ${genComp(valueArrayType, "aValues", "bValues")};
if (valueComp != 0) {
return valueComp;
}
return 0;
}
"""
s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"

addNewFunction(compareMapFunc,
s"""
|public int $compareMapFunc(MapData left, MapData right) {
| if (left.numElements() != right.numElements()) {
| return left.numElements() - right.numElements();
| }
|
| int numElements = left.numElements();
| ArrayData leftKeys = left.keyArray();
| ArrayData rightKeys = right.keyArray();
| ArrayData leftValues = left.valueArray();
| ArrayData rightValues = right.valueArray();
|
| Integer[] leftSortedKeyIndex = $initIndexArrayFunc(numElements);
| Integer[] rightSortedKeyIndex = $initIndexArrayFunc(numElements);
| java.util.Arrays.sort(leftSortedKeyIndex, new $keyIndexComparator(leftKeys));
| java.util.Arrays.sort(rightSortedKeyIndex, new $keyIndexComparator(rightKeys));
|
| for (int i = 0; i < numElements; i++) {
| int leftIndex = leftSortedKeyIndex[i];
| int rightIndex = rightSortedKeyIndex[i];
| int keyComp = $compareKeyFunc(leftKeys, leftIndex, rightKeys, rightIndex);
| if (keyComp != 0) {
| return keyComp;
| } else {
| int valueComp = $compareValueFunc(leftValues, leftIndex, rightValues, rightIndex);
| if (valueComp != 0) {
| return valueComp;
| }
| }
| }
| return 0;
|}
|""".stripMargin)
s"this.$compareMapFunc($c1, $c2)"
case schema: StructType =>
val comparisons = GenerateOrdering.genComparisons(this, schema)
val compareFunc = freshName("compareStruct")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -141,24 +141,23 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {

case _ if expr.dataType.isInstanceOf[MapType] =>
val MapType(kt, vt, containsNull) = expr.dataType
var normalized = if (needNormalize(kt)) {
val lv1 = NamedLambdaVariable("arg1", kt, false)
val maybeKeyNormalized = if (needNormalize(kt)) {
val lv1 = NamedLambdaVariable("arg1", kt, nullable = false)
val lv2 = NamedLambdaVariable("arg2", vt, containsNull)
val function = normalize(lv1)
TransformKeys(expr, LambdaFunction(function, Seq(lv1, lv2)))
} else {
expr
}

normalized = if (needNormalize(vt)) {
val lv1 = NamedLambdaVariable("arg1", kt, false)
val maybeKeyValueNormalized = if (needNormalize(vt)) {
val lv1 = NamedLambdaVariable("arg1", kt, nullable = false)
val lv2 = NamedLambdaVariable("arg2", vt, containsNull)
val function = normalize(lv2)
TransformValues(normalized, LambdaFunction(function, Seq(lv1, lv2)))
TransformValues(maybeKeyNormalized, LambdaFunction(function, Seq(lv1, lv2)))
} else {
normalized
maybeKeyNormalized
}
KnownFloatingPointNormalized(normalized)
KnownFloatingPointNormalized(maybeKeyValueNormalized)

case _ => throw new IllegalStateException(s"fail to normalize $expr")
}
Expand Down

This file was deleted.

Loading

0 comments on commit 38e42c4

Please sign in to comment.