[SPARK-34819][SQL] MapType supports orderable semantics #31967

Closed · wants to merge 5 commits
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -639,14 +639,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog {
|Conflicting attributes: ${conflictingAttributes.mkString(",")}
""".stripMargin)

- // TODO: although map type is not orderable, technically map type should be able to be
- // used in equality comparison, remove this type check once we support it.
- case o if mapColumnInSetOperation(o).isDefined =>
-   val mapCol = mapColumnInSetOperation(o).get
-   failAnalysis("Cannot have map type columns in DataFrame which calls " +
-     s"set operations(intersect, except, etc.), but the type of column ${mapCol.name} " +
-     "is " + mapCol.dataType.catalogString)
-
case o if o.expressions.exists(!_.deterministic) &&
!o.isInstanceOf[Project] && !o.isInstanceOf[Filter] &&
!o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] =>
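With this guard removed, map-typed columns are accepted by set operations. A hypothetical spark-shell session against a build that includes this PR (the data is illustrative, not taken from the PR) shows a query that previously failed analysis with exactly the removed error message:

```scala
// Hypothetical illustration, assuming this PR is applied: map columns may now
// appear in set operations. Without the PR, intersect() fails analysis with
// "Cannot have map type columns in DataFrame which calls set operations(...)".
val df1 = spark.sql("SELECT map('a', 1, 'b', 2) AS m")
val df2 = spark.sql("SELECT map('b', 2, 'a', 1) AS m")
// Map entries now compare order-insensitively, so the two rows are equal and
// the intersection yields a single row.
df1.intersect(df2).show()
```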
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala
@@ -612,6 +612,7 @@ class CodegenContext extends Logging {
case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)"
case array: ArrayType => genComp(array, c1, c2) + " == 0"
case struct: StructType => genComp(struct, c1, c2) + " == 0"
case map: MapType => genComp(map, c1, c2) + " == 0"
case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2)
case NullType => "false"
case _ =>
@@ -687,6 +688,118 @@
}
"""
s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)"
case _ @ MapType(keyType, valueType, valueContainsNull) =>
val compareMapFunc = freshName("compareMap")
val initIndexArrayFunc = freshName("initIndexArray")
val keyIndexComparator = freshName("keyIndexComparator")
val compareKeyFunc = freshName("compareKey")
val compareValueFunc = freshName("compareValue")
val nullSafeCompare =
s"""
|${javaType(valueType)} left = ${getValue("leftArray", valueType, "leftIndex")};
|${javaType(valueType)} right = ${getValue("rightArray", valueType, "rightIndex")};
|return ${genComp(valueType, "left", "right")};
|""".stripMargin
val compareElement = if (valueContainsNull) {
s"""
|boolean isNullA = leftArray.isNullAt(leftIndex);
|boolean isNullB = rightArray.isNullAt(rightIndex);
|if (isNullA && isNullB) {
| return 0;
|} else if (isNullA) {
| return -1;
|} else if (isNullB) {
| return 1;
|} else {
| $nullSafeCompare
|}
|""".stripMargin
} else {
nullSafeCompare
}

addNewFunction(initIndexArrayFunc,
s"""
|private Integer[] $initIndexArrayFunc(int n) {
| Integer[] arr = new Integer[n];
| for (int i = 0; i < n; i++) {
| arr[i] = i;
| }
| return arr;
|}""".stripMargin)


addNewFunction(keyIndexComparator,
s"""
|private class $keyIndexComparator implements java.util.Comparator<Integer> {
| private ArrayData array;
| public $keyIndexComparator(ArrayData array) {
| this.array = array;
| }
|
| @Override
| public int compare(Object a, Object b) {
| int indexA = ((Integer)a).intValue();
| int indexB = ((Integer)b).intValue();
| ${javaType(keyType)} keyA = ${getValue("array", keyType, "indexA")};
| ${javaType(keyType)} keyB = ${getValue("array", keyType, "indexB")};
| return ${genComp(keyType, "keyA", "keyB")};
| }
|}""".stripMargin)

addNewFunction(compareKeyFunc,
s"""
|private int $compareKeyFunc(ArrayData leftArray, int leftIndex, ArrayData rightArray,
| int rightIndex) {
| ${javaType(keyType)} left = ${getValue("leftArray", keyType, "leftIndex")};
| ${javaType(keyType)} right = ${getValue("rightArray", keyType, "rightIndex")};
| return ${genComp(keyType, "left", "right")};
|}
|""".stripMargin)

addNewFunction(compareValueFunc,
s"""
|private int $compareValueFunc(ArrayData leftArray, int leftIndex, ArrayData rightArray,
| int rightIndex) {
| $compareElement
|}
|""".stripMargin)

addNewFunction(compareMapFunc,
s"""
|public int $compareMapFunc(MapData left, MapData right) {
| if (left.numElements() != right.numElements()) {
| return left.numElements() - right.numElements();
| }
|
| int numElements = left.numElements();
| ArrayData leftKeys = left.keyArray();
| ArrayData rightKeys = right.keyArray();
| ArrayData leftValues = left.valueArray();
| ArrayData rightValues = right.valueArray();
|
| Integer[] leftSortedKeyIndex = $initIndexArrayFunc(numElements);
| Integer[] rightSortedKeyIndex = $initIndexArrayFunc(numElements);
| java.util.Arrays.sort(leftSortedKeyIndex, new $keyIndexComparator(leftKeys));
| java.util.Arrays.sort(rightSortedKeyIndex, new $keyIndexComparator(rightKeys));
|
| for (int i = 0; i < numElements; i++) {
| int leftIndex = leftSortedKeyIndex[i];
| int rightIndex = rightSortedKeyIndex[i];
| int keyComp = $compareKeyFunc(leftKeys, leftIndex, rightKeys, rightIndex);
| if (keyComp != 0) {
| return keyComp;
| } else {
| int valueComp = $compareValueFunc(leftValues, leftIndex, rightValues, rightIndex);
| if (valueComp != 0) {
| return valueComp;
| }
| }
| }
| return 0;
|}
|""".stripMargin)
s"this.$compareMapFunc($c1, $c2)"
case schema: StructType =>
val comparisons = GenerateOrdering.genComparisons(this, schema)
val compareFunc = freshName("compareStruct")
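The generated Java is hard to follow inside interpolated strings, so here is a minimal interpreted Scala sketch of the comparison the new MapType branch implements (compareMaps is a hypothetical helper over plain Scala maps, standing in for Spark's MapData/ArrayData): maps of different sizes order by size; otherwise both entry lists are sorted by key and compared pairwise, key first, then value.

```scala
// Minimal sketch of the ordering semantics generated above, assuming plain
// Scala Map[Int, Int] instead of Spark's MapData. Not part of the PR.
def compareMaps(left: Map[Int, Int], right: Map[Int, Int]): Int = {
  if (left.size != right.size) {
    left.size - right.size // the smaller map sorts first
  } else {
    // Sort both entry sequences by key, mirroring the sorted index arrays
    // built by initIndexArray and keyIndexComparator in the generated code.
    val l = left.toSeq.sortBy(_._1)
    val r = right.toSeq.sortBy(_._1)
    l.zip(r).iterator.map { case ((lk, lv), (rk, rv)) =>
      val keyComp = lk.compareTo(rk)                   // compareKey
      if (keyComp != 0) keyComp else lv.compareTo(rv)  // compareValue
    }.find(_ != 0).getOrElse(0)
  }
}
```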
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala
@@ -65,6 +65,10 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends BaseOrdering {
a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
case a: ArrayType if order.direction == Descending =>
- a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
case a: MapType if order.direction == Ascending =>
a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
case a: MapType if order.direction == Descending =>
- a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
case s: StructType if order.direction == Ascending =>
s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right)
case s: StructType if order.direction == Descending =>
@@ -104,6 +108,7 @@ object RowOrdering extends CodeGeneratorWithInterpretedFallback[Seq[SortOrder],
case dt: AtomicType => true
case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType))
case array: ArrayType => isOrderable(array.elementType)
case map: MapType => isOrderable(map.keyType) && isOrderable(map.valueType)
case udt: UserDefinedType[_] => isOrderable(udt.sqlType)
case _ => false
}
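Assuming a build with this patch, the recursive check now admits maps whose key and value types are themselves orderable, for example:

```scala
// Sketch of the recursive isOrderable check, assuming this patch is applied.
import org.apache.spark.sql.catalyst.expressions.RowOrdering
import org.apache.spark.sql.types._

RowOrdering.isOrderable(MapType(IntegerType, StringType))           // true
RowOrdering.isOrderable(MapType(StringType, ArrayType(DoubleType))) // true: element type is orderable
RowOrdering.isOrderable(MapType(IntegerType, CalendarIntervalType)) // false: interval values are not orderable
```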
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala
@@ -17,7 +17,7 @@

package org.apache.spark.sql.catalyst.optimizer

- import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression}
+ import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, TransformKeys, TransformValues, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window}
@@ -95,9 +95,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
case FloatType | DoubleType => true
case StructType(fields) => fields.exists(f => needNormalize(f.dataType))
case ArrayType(et, _) => needNormalize(et)
- // Currently MapType is not comparable and analyzer should fail earlier if this case happens.
- case _: MapType =>
-   throw new IllegalStateException("grouping/join/window partition keys cannot be map type.")
+ case MapType(kt, vt, _) => needNormalize(kt) || needNormalize(vt)
case _ => false
}

@@ -141,6 +139,27 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] {
val function = normalize(lv)
KnownFloatingPointNormalized(ArrayTransform(expr, LambdaFunction(function, Seq(lv))))

case _ if expr.dataType.isInstanceOf[MapType] =>
val MapType(kt, vt, containsNull) = expr.dataType
var normalized = if (needNormalize(kt)) {
Review comment (Member): Could you avoid using a var here?
Review comment (Member): Could you add tests for this new code path in NormalizeFloatingPointNumbersSuite?

val lv1 = NamedLambdaVariable("arg1", kt, false)
val lv2 = NamedLambdaVariable("arg2", vt, containsNull)
val function = normalize(lv1)
TransformKeys(expr, LambdaFunction(function, Seq(lv1, lv2)))
} else {
expr
}

normalized = if (needNormalize(vt)) {
val lv1 = NamedLambdaVariable("arg1", kt, false)
val lv2 = NamedLambdaVariable("arg2", vt, containsNull)
val function = normalize(lv2)
TransformValues(normalized, LambdaFunction(function, Seq(lv1, lv2)))
} else {
normalized
}
KnownFloatingPointNormalized(normalized)

case _ => throw new IllegalStateException(s"fail to normalize $expr")
}

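The reason for recursing into key and value types: floating-point values with different bit patterns (0.0 vs. -0.0, and the many NaN encodings) must compare equal, so float/double map keys and values are rewritten through TransformKeys/TransformValues before maps are compared as binary data. A hypothetical session, assuming this PR is applied:

```scala
// Hypothetical illustration: 0.0 and -0.0 have different binary encodings, so
// a map used as a join key must have its float/double keys normalized first.
val a = spark.sql("SELECT map(0.0D, 1) AS m")
val b = spark.sql("SELECT map(-0.0D, 1) AS m")
// With the TransformKeys-based normalization in place, the keys match and the
// equi-join returns one row.
a.join(b, "m").show()
```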
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapType.scala (new file)
@@ -0,0 +1,155 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.catalyst.optimizer

import scala.math.Ordering

import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, ExpectsInputTypes, Expression, UnaryExpression}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator.{getValue, javaType}
import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapBuilder, MapData, TypeUtils}
import org.apache.spark.sql.types.{AbstractDataType, DataType, MapType}

/**
 * When comparing two maps, we have to make sure that two maps with the same key-value
 * pairs but different key orderings are considered equal.
 * For example, Map('a' -> 1, 'b' -> 2) is equal to Map('b' -> 2, 'a' -> 1).
 *
 * We have to handle this specially in grouping/join/window, because Spark SQL turns
 * grouping/join/window partition keys into binary `UnsafeRow`s and compares the
 * binary data directly instead of using MapType's ordering. So in these cases, we have
 * to insert an expression that sorts map entries by key.
 *
 * Note that this rule must be executed at the end of the optimizer, because the optimizer
 * may create new joins (the subquery rewrite) and new join conditions (the join reorder).
 */
Review comment (Member): Could you leave some comments about why this rule does not handle the Aggregate cases? https://github.com/apache/spark/pull/31967/files#diff-21f071d73070b8257ad76e6e16ec5ed38a13d1278fe94bd42546c258a69f4410R344
object NormalizeMapType extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case w: Window if w.partitionSpec.exists(p => needNormalize(p)) =>
Review comment (Member): You didn't support BinaryComparison cases?
w.copy(partitionSpec = w.partitionSpec.map(normalize))

case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _)
// The analyzer guarantees left and right join keys are of the same data type.
if leftKeys.exists(k => needNormalize(k)) =>
val newLeftJoinKeys = leftKeys.map(normalize)
val newRightJoinKeys = rightKeys.map(normalize)
val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map {
case (l, r) => EqualTo(l, r)
} ++ condition
j.copy(condition = Some(newConditions.reduce(And)))
}

private def needNormalize(expr: Expression): Boolean = expr match {
case SortMapKey(_) => false
case e if e.dataType.isInstanceOf[MapType] => true
case _ => false
}

private[sql] def normalize(expr: Expression): Expression = expr match {
case _ if !needNormalize(expr) => expr
case e if e.dataType.isInstanceOf[MapType] =>
SortMapKey(e)
}
}

case class SortMapKey(child: Expression) extends UnaryExpression with ExpectsInputTypes {
Review comment (Member): SortMapKey -> SortMapKeys?
private lazy val MapType(keyType, valueType, valueContainsNull) = dataType.asInstanceOf[MapType]
private lazy val keyOrdering: Ordering[Any] = TypeUtils.getInterpretedOrdering(keyType)
private lazy val mapBuilder = new ArrayBasedMapBuilder(keyType, valueType)

override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

override def dataType: DataType = child.dataType

override def nullSafeEval(input: Any): Any = {
val childMap = input.asInstanceOf[MapData]
val keys = childMap.keyArray()
Review comment (Member): We don't need to sort data recursively just for nested cases like map<map<int,int>,string> and map<struct<a: map<int,int>>,string>?
Reply (Contributor, author): Seems that I missed this case. I'll fix it.
val values = childMap.valueArray()
val sortedKeyIndex = (0 until childMap.numElements()).toArray.sorted(new Ordering[Int] {
override def compare(a: Int, b: Int): Int = {
keyOrdering.compare(keys.get(a, keyType), keys.get(b, keyType))
}
})

var i = 0
while (i < childMap.numElements()) {
val index = sortedKeyIndex(i)
mapBuilder.put(
keys.get(index, keyType),
if (values.isNullAt(index)) null else values.get(index, valueType))

i += 1
}

mapBuilder.build()
}

override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
Review comment (Member): To make this PR simpler, how about leaving the codegen support to follow-up PRs, just like the original PR? https://github.com/apache/spark/pull/15970/files#diff-da163d97a5f0fc534aad719c4a39eca97116f25bfc05b7d8941b342a3ed96036R423-R429
val initIndexArrayFunc = ctx.freshName("initIndexArray")
val numElements = ctx.freshName("numElements")
val sortedKeyIndex = ctx.freshName("sortedKeyIndex")
val keyArray = ctx.freshName("keyArray")
val valueArray = ctx.freshName("valueArray")
val idx = ctx.freshName("idx")
val index = ctx.freshName("index")
val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder)
ctx.addNewFunction(initIndexArrayFunc,
s"""
|private Integer[] $initIndexArrayFunc(int n) {
| Integer[] arr = new Integer[n];
| for (int i = 0; i < n; i++) {
| arr[i] = i;
| }
| return arr;
|}""".stripMargin)

val codeToNormalize = (f: String) => {
s"""
|int $numElements = $f.numElements();
|Integer[] $sortedKeyIndex = $initIndexArrayFunc($numElements);
|final ArrayData $keyArray = $f.keyArray();
|final ArrayData $valueArray = $f.valueArray();
|java.util.Arrays.sort($sortedKeyIndex, new java.util.Comparator<Integer>() {
| @Override
| public int compare(Object a, Object b) {
| int indexA = ((Integer)a).intValue();
| int indexB = ((Integer)b).intValue();
| ${javaType(keyType)} keyA = ${getValue(keyArray, keyType, "indexA")};
| ${javaType(keyType)} keyB = ${getValue(keyArray, keyType, "indexB")};
| return ${ctx.genComp(keyType, "keyA", "keyB")};
| }
|});
|
|for (int $idx = 0; $idx < $numElements; $idx++) {
| Integer $index = $sortedKeyIndex[$idx];
| $builderTerm.put(
| ${getValue(keyArray, keyType, index)},
| $valueArray.isNullAt($index) ? null : ${getValue(valueArray, valueType, index)});
|}
|
|${ev.value} = $builderTerm.build();
|""".stripMargin
}

nullSafeCodeGen(ctx, ev, codeToNormalize)
}
}
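SortMapKey exists to establish a simple invariant: once entries are sorted by key, two logically equal maps produce identical entry sequences, so their binary UnsafeRow encodings match and can be compared byte-wise. A tiny sketch of that invariant in plain Scala (illustrative only, not part of the PR):

```scala
// The invariant SortMapKey establishes, sketched with plain Scala sequences:
// sorting entries by key makes logically equal maps structurally identical.
val m1 = Seq("b" -> 2, "a" -> 1)
val m2 = Seq("a" -> 1, "b" -> 2)
assert(m1 != m2)                            // entry order differs
assert(m1.sortBy(_._1) == m2.sortBy(_._1))  // normalized forms are identical
```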
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -230,10 +230,12 @@ abstract class Optimizer(catalogManager: CatalogManager)
ColumnPruning,
CollapseProject,
RemoveNoopOperators) :+
- // This batch must be executed after the `RewriteSubquery` batch, which creates joins.
+ // Following batches must be executed after the `RewriteSubquery` batch, which creates joins.
Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+
Batch("NormalizeMapType", Once, NormalizeMapType) :+
Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression)


Review comment (Member): nit: unnecessary change.
// remove any batches with no rules. this may happen when subclasses do not add optional rules.
batches.filter(_.rules.nonEmpty)
}
@@ -266,7 +268,8 @@ abstract class Optimizer(catalogManager: CatalogManager)
RewriteCorrelatedScalarSubquery.ruleName ::
RewritePredicateSubquery.ruleName ::
NormalizeFloatingNumbers.ruleName ::
- ReplaceUpdateFieldsExpression.ruleName :: Nil
+ ReplaceUpdateFieldsExpression.ruleName ::
+   NormalizeMapType.ruleName :: Nil

/**
* Optimize all the subqueries inside expression.
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala
@@ -73,6 +73,7 @@ object TypeUtils {
t match {
case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]]
case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
case m: MapType => m.interpretedOrdering.asInstanceOf[Ordering[Any]]
case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
case udt: UserDefinedType[_] => getInterpretedOrdering(udt.sqlType)
}
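With the MapType case added, the interpreted (non-codegen) path can also obtain an ordering for map values. A short sketch, assuming a build with this PR (where MapType gains an interpretedOrdering, as the InterpretedOrdering diff above implies):

```scala
// Assumes this PR is applied, so MapType exposes interpretedOrdering.
import org.apache.spark.sql.catalyst.util.TypeUtils
import org.apache.spark.sql.types.{IntegerType, MapType, StringType}

// An Ordering[Any] usable by InterpretedOrdering to sort or compare map values.
val ord = TypeUtils.getInterpretedOrdering(MapType(StringType, IntegerType))
```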