-
Notifications
You must be signed in to change notification settings - Fork 28.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[SPARK-34819][SQL] MapType supports orderable semantics #31967
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,7 +17,7 @@ | |
|
||
package org.apache.spark.sql.catalyst.optimizer | ||
|
||
import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression} | ||
import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, TransformKeys, TransformValues, UnaryExpression} | ||
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} | ||
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys | ||
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window} | ||
|
@@ -95,9 +95,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { | |
case FloatType | DoubleType => true | ||
case StructType(fields) => fields.exists(f => needNormalize(f.dataType)) | ||
case ArrayType(et, _) => needNormalize(et) | ||
// Currently MapType is not comparable and analyzer should fail earlier if this case happens. | ||
case _: MapType => | ||
throw new IllegalStateException("grouping/join/window partition keys cannot be map type.") | ||
case MapType(kt, vt, _) => needNormalize(kt) || needNormalize(vt) | ||
case _ => false | ||
} | ||
|
||
|
@@ -141,6 +139,27 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { | |
val function = normalize(lv) | ||
KnownFloatingPointNormalized(ArrayTransform(expr, LambdaFunction(function, Seq(lv)))) | ||
|
||
case _ if expr.dataType.isInstanceOf[MapType] => | ||
val MapType(kt, vt, containsNull) = expr.dataType | ||
var normalized = if (needNormalize(kt)) { | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you avoid using There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Could you add tests for this new code path in |
||
val lv1 = NamedLambdaVariable("arg1", kt, false) | ||
val lv2 = NamedLambdaVariable("arg2", vt, containsNull) | ||
val function = normalize(lv1) | ||
TransformKeys(expr, LambdaFunction(function, Seq(lv1, lv2))) | ||
} else { | ||
expr | ||
} | ||
|
||
normalized = if (needNormalize(vt)) { | ||
val lv1 = NamedLambdaVariable("arg1", kt, false) | ||
val lv2 = NamedLambdaVariable("arg2", vt, containsNull) | ||
val function = normalize(lv2) | ||
TransformValues(normalized, LambdaFunction(function, Seq(lv1, lv2))) | ||
} else { | ||
normalized | ||
} | ||
KnownFloatingPointNormalized(normalized) | ||
|
||
case _ => throw new IllegalStateException(s"fail to normalize $expr") | ||
} | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
/* | ||
* Licensed to the Apache Software Foundation (ASF) under one or more | ||
* contributor license agreements. See the NOTICE file distributed with | ||
* this work for additional information regarding copyright ownership. | ||
* The ASF licenses this file to You under the Apache License, Version 2.0 | ||
* (the "License"); you may not use this file except in compliance with | ||
* the License. You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
package org.apache.spark.sql.catalyst.optimizer | ||
|
||
import scala.math.Ordering | ||
|
||
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, ExpectsInputTypes, Expression, UnaryExpression} | ||
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext | ||
import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator.{getValue, javaType} | ||
import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode | ||
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys | ||
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window} | ||
import org.apache.spark.sql.catalyst.rules.Rule | ||
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapBuilder, MapData, TypeUtils} | ||
import org.apache.spark.sql.types.{AbstractDataType, DataType, MapType} | ||
|
||
/** | ||
* When comparing two maps, we have to make sure that two maps with the same key-value pairs but | ||
* different key orderings are equal. | ||
* For example, Map('a' -> 1, 'b' -> 2) is equal to Map('b' -> 2, 'a' -> 1). | ||
* | ||
* We have to specially handle this in grouping/join/window because Spark SQL turns | ||
* grouping/join/window partition keys into binary `UnsafeRow` and compare the | ||
* binary data directly instead of using MapType's ordering. So in these cases, we have | ||
* to insert an expression to sort map entries by key. | ||
* | ||
* Note that this rule must be executed at the end of the optimizer, because the optimizer may create | ||
* new joins (the subquery rewrite) and new join conditions (the join reorder). | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you leave some comments about why this rule does not handle the Aggregate cases? https://github.com/apache/spark/pull/31967/files#diff-21f071d73070b8257ad76e6e16ec5ed38a13d1278fe94bd42546c258a69f4410R344 |
||
*/ | ||
object NormalizeMapType extends Rule[LogicalPlan] { | ||
def apply(plan: LogicalPlan): LogicalPlan = plan transform { | ||
case w: Window if w.partitionSpec.exists(p => needNormalize(p)) => | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You didn't support |
||
w.copy(partitionSpec = w.partitionSpec.map(normalize)) | ||
|
||
case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _) | ||
// The analyzer guarantees left and right join keys are of the same data type. | ||
if leftKeys.exists(k => needNormalize(k)) => | ||
val newLeftJoinKeys = leftKeys.map(normalize) | ||
val newRightJoinKeys = rightKeys.map(normalize) | ||
val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map { | ||
case (l, r) => EqualTo(l, r) | ||
} ++ condition | ||
j.copy(condition = Some(newConditions.reduce(And))) | ||
} | ||
|
||
private def needNormalize(expr: Expression): Boolean = expr match { | ||
case SortMapKey(_) => false | ||
case e if e.dataType.isInstanceOf[MapType] => true | ||
case _ => false | ||
} | ||
|
||
private[sql] def normalize(expr: Expression): Expression = expr match { | ||
case _ if !needNormalize(expr) => expr | ||
case e if e.dataType.isInstanceOf[MapType] => | ||
SortMapKey(e) | ||
} | ||
} | ||
|
||
case class SortMapKey(child: Expression) extends UnaryExpression with ExpectsInputTypes { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
private lazy val MapType(keyType, valueType, valueContainsNull) = dataType.asInstanceOf[MapType] | ||
private lazy val keyOrdering: Ordering[Any] = TypeUtils.getInterpretedOrdering(keyType) | ||
private lazy val mapBuilder = new ArrayBasedMapBuilder(keyType, valueType) | ||
|
||
override def inputTypes: Seq[AbstractDataType] = Seq(MapType) | ||
|
||
override def dataType: DataType = child.dataType | ||
|
||
override def nullSafeEval(input: Any): Any = { | ||
val childMap = input.asInstanceOf[MapData] | ||
val keys = childMap.keyArray() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We don't need to sort data recursively just for nested case like There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Seems that I missed this case. I'll fix it |
||
val values = childMap.valueArray() | ||
val sortedKeyIndex = (0 until childMap.numElements()).toArray.sorted(new Ordering[Int] { | ||
override def compare(a: Int, b: Int): Int = { | ||
keyOrdering.compare(keys.get(a, keyType), keys.get(b, keyType)) | ||
} | ||
}) | ||
|
||
var i = 0 | ||
while (i < childMap.numElements()) { | ||
val index = sortedKeyIndex(i) | ||
mapBuilder.put( | ||
keys.get(index, keyType), | ||
if (values.isNullAt(index)) null else values.get(index, valueType)) | ||
|
||
i += 1 | ||
} | ||
|
||
mapBuilder.build() | ||
} | ||
|
||
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. To make this PR simpler, how about leaving the codegen support into follow-up PRs just like the original PR? https://github.com/apache/spark/pull/15970/files#diff-da163d97a5f0fc534aad719c4a39eca97116f25bfc05b7d8941b342a3ed96036R423-R429 |
||
val initIndexArrayFunc = ctx.freshName("initIndexArray") | ||
val numElements = ctx.freshName("numElements") | ||
val sortedKeyIndex = ctx.freshName("sortedKeyIndex") | ||
val keyArray = ctx.freshName("keyArray") | ||
val valueArray = ctx.freshName("valueArray") | ||
val idx = ctx.freshName("idx") | ||
val index = ctx.freshName("index") | ||
val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder) | ||
ctx.addNewFunction(initIndexArrayFunc, | ||
s""" | ||
|private Integer[] $initIndexArrayFunc(int n) { | ||
| Integer[] arr = new Integer[n]; | ||
| for (int i = 0; i < n; i++) { | ||
| arr[i] = i; | ||
| } | ||
| return arr; | ||
|}""".stripMargin) | ||
|
||
val codeToNormalize = (f: String) => { | ||
s""" | ||
|int $numElements = $f.numElements(); | ||
|Integer[] $sortedKeyIndex = $initIndexArrayFunc($numElements); | ||
|final ArrayData $keyArray = $f.keyArray(); | ||
|final ArrayData $valueArray = $f.valueArray(); | ||
|java.util.Arrays.sort($sortedKeyIndex, new java.util.Comparator<Integer>() { | ||
| @Override | ||
| public int compare(Object a, Object b) { | ||
| int indexA = ((Integer)a).intValue(); | ||
| int indexB = ((Integer)b).intValue(); | ||
| ${javaType(keyType)} keyA = ${getValue(keyArray, keyType, "indexA")}; | ||
| ${javaType(keyType)} keyB = ${getValue(keyArray, keyType, "indexB")}; | ||
| return ${ctx.genComp(keyType, "keyA", "keyB")}; | ||
| } | ||
|}); | ||
| | ||
|for (int $idx = 0; $idx < $numElements; $idx++) { | ||
| Integer $index = $sortedKeyIndex[$idx]; | ||
| $builderTerm.put( | ||
| ${getValue(keyArray, keyType, index)}, | ||
| $valueArray.isNullAt($index) ? null : ${getValue(valueArray, valueType, index)}); | ||
|} | ||
| | ||
|${ev.value} = $builderTerm.build(); | ||
|""".stripMargin | ||
} | ||
|
||
nullSafeCodeGen(ctx, ev, codeToNormalize) | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -230,10 +230,12 @@ abstract class Optimizer(catalogManager: CatalogManager) | |
ColumnPruning, | ||
CollapseProject, | ||
RemoveNoopOperators) :+ | ||
// This batch must be executed after the `RewriteSubquery` batch, which creates joins. | ||
// Following batches must be executed after the `RewriteSubquery` batch, which creates joins. | ||
Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+ | ||
Batch("NormalizeMapType", Once, NormalizeMapType) :+ | ||
Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression) | ||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: unnecessary change. |
||
// remove any batches with no rules. this may happen when subclasses do not add optional rules. | ||
batches.filter(_.rules.nonEmpty) | ||
} | ||
|
@@ -266,7 +268,8 @@ abstract class Optimizer(catalogManager: CatalogManager) | |
RewriteCorrelatedScalarSubquery.ruleName :: | ||
RewritePredicateSubquery.ruleName :: | ||
NormalizeFloatingNumbers.ruleName :: | ||
ReplaceUpdateFieldsExpression.ruleName :: Nil | ||
ReplaceUpdateFieldsExpression.ruleName :: | ||
NormalizeMapType.ruleName :: Nil | ||
|
||
/** | ||
* Optimize all the subqueries inside expression. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What's the difference from the @hvanhovell impl.? The @hvanhovell one looks simpler though.
https://github.com/apache/spark/pull/15970/files#diff-1501206e78d34b65183af1092c8ec392ce18574bb538f905ca93a22983c63ae6R558-R598
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Btw, we cannot reuse the Array case? https://github.com/apache/spark/pull/31967/files#diff-1501206e78d34b65183af1092c8ec392ce18574bb538f905ca93a22983c63ae6R643