From 7d160980424b8d00491bf1d028dc1e6a3490912a Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Tue, 9 Mar 2021 22:23:38 +0800 Subject: [PATCH 1/5] MapType supports comparable/orderable semantics --- .../sql/catalyst/analysis/CheckAnalysis.scala | 8 - .../expressions/codegen/CodeGenerator.scala | 108 +++++++++++ .../sql/catalyst/expressions/ordering.scala | 5 + .../optimizer/NormalizeFloatingNumbers.scala | 27 ++- .../catalyst/optimizer/NormalizeMapType.scala | 179 ++++++++++++++++++ .../sql/catalyst/optimizer/Optimizer.scala | 7 +- .../spark/sql/catalyst/util/TypeUtils.scala | 1 + .../org/apache/spark/sql/types/MapType.scala | 66 +++++++ .../analysis/AnalysisErrorSuite.scala | 23 --- .../ExpressionTypeCheckingSuite.scala | 15 -- .../catalyst/expressions/PredicateSuite.scala | 8 - .../spark/sql/DataFrameAggregateSuite.scala | 28 ++- .../spark/sql/DataFrameFunctionsSuite.scala | 5 - .../sql/DataFrameSetOperationsSuite.scala | 35 ++-- .../sql/DataFrameWindowFunctionsSuite.scala | 12 ++ .../org/apache/spark/sql/JoinSuite.scala | 20 +- .../org/apache/spark/sql/SQLQuerySuite.scala | 140 +++++++++++++- .../sql/sources/BucketedWriteSuite.scala | 5 - 18 files changed, 590 insertions(+), 102 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapType.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index c699942ab55ca..7ef81593e2c56 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -639,14 +639,6 @@ trait CheckAnalysis extends PredicateHelper with LookupCatalog { |Conflicting attributes: ${conflictingAttributes.mkString(",")} """.stripMargin) - // TODO: although map type is not orderable, technically map type should be able to be - // used in equality comparison, remove this type check once we support it. 
- case o if mapColumnInSetOperation(o).isDefined => - val mapCol = mapColumnInSetOperation(o).get - failAnalysis("Cannot have map type columns in DataFrame which calls " + - s"set operations(intersect, except, etc.), but the type of column ${mapCol.name} " + - "is " + mapCol.dataType.catalogString) - case o if o.expressions.exists(!_.deterministic) && !o.isInstanceOf[Project] && !o.isInstanceOf[Filter] && !o.isInstanceOf[Aggregate] && !o.isInstanceOf[Window] => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index e01f7dc16a663..bf872b29e195d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -612,6 +612,7 @@ class CodegenContext extends Logging { case dt: DataType if dt.isInstanceOf[AtomicType] => s"$c1.equals($c2)" case array: ArrayType => genComp(array, c1, c2) + " == 0" case struct: StructType => genComp(struct, c1, c2) + " == 0" + case map: MapType => genComp(map, c1, c2) + " == 0" case udt: UserDefinedType[_] => genEqual(udt.sqlType, c1, c2) case NullType => "false" case _ => @@ -687,6 +688,113 @@ class CodegenContext extends Logging { } """ s"${addNewFunction(compareFunc, funcCode)}($c1, $c2)" + case _ @ MapType(keyType, valueType, valueContainsNull) => + val compareMapFunc = freshName("compareMap") + val initIndexArrayFunc = freshName("initIndexArray") + val keyIndexComparator = freshName("keyIndexComparator") + val compareKeyFunc = freshName("compareKey") + val compareValueFunc = freshName("compareValue") + val nullValueCheck = if (valueContainsNull) { + s""" + |boolean isNullA = leftArray.isNullAt(leftIndex); + |boolean isNullB = rightArray.isNullAt(rightIndex); + |if (isNullA && isNullB) { + | // do nothing + |} else if (isNullA) { + | return -1; + |} else if (isNullB) { + | return 1; + |} + |""".stripMargin + } else { + "" + } + + addNewFunction(initIndexArrayFunc, + s""" + |private Integer[] $initIndexArrayFunc(int n) { + | Integer[] arr = new Integer[n]; + | for (int i = 0; i < n; i++) { + | arr[i] = i; + | } + | return arr; + |}""".stripMargin) + + + addNewFunction(keyIndexComparator, + s""" + |private class $keyIndexComparator implements java.util.Comparator { + | private ArrayData array; + | public $keyIndexComparator(ArrayData array) { + | this.array = array; + | } + | + | @Override + | public int compare(Object a, Object b) { + | Integer indexA = (Integer)a; + | Integer indexB = (Integer)b; + | ${javaType(keyType)} keyA = ${getValue("array", keyType, "indexA")}; + | ${javaType(keyType)} keyB = ${getValue("array", keyType, "indexB")}; + | return ${genComp(keyType, "keyA", "keyB")}; + | } + |}""".stripMargin) + + addNewFunction(compareKeyFunc, + s""" + |private int $compareKeyFunc(ArrayData leftArray, int leftIndex, ArrayData rightArray, + | int rightIndex) { + | ${javaType(keyType)} left = ${getValue("leftArray", keyType, "leftIndex")}; + | ${javaType(keyType)} right = ${getValue("rightArray", keyType, "rightIndex")}; + | return ${genComp(keyType, "left", "right")}; + |} + |""".stripMargin) + + addNewFunction(compareValueFunc, + s""" + |private int $compareValueFunc(ArrayData leftArray, int leftIndex, ArrayData rightArray, + | int rightIndex) { + | $nullValueCheck + | ${javaType(valueType)} left = ${getValue("leftArray", valueType, "leftIndex")}; + 
| ${javaType(valueType)} right = ${getValue("rightArray", valueType, "rightIndex")}; + | return ${genComp(valueType, "left", "right")}; + |} + |""".stripMargin) + + addNewFunction(compareMapFunc, + s""" + |public int $compareMapFunc(MapData left, MapData right) { + | if (left.numElements() != right.numElements()) { + | return left.numElements() - right.numElements(); + | } + | + | int numElements = left.numElements(); + | ArrayData leftKeys = left.keyArray(); + | ArrayData rightKeys = right.keyArray(); + | ArrayData leftValues = left.valueArray(); + | ArrayData rightValues = right.valueArray(); + | + | Integer[] leftSortedKeyIndex = $initIndexArrayFunc(numElements); + | Integer[] rightSortedKeyIndex = $initIndexArrayFunc(numElements); + | java.util.Arrays.sort(leftSortedKeyIndex, new $keyIndexComparator(leftKeys)); + | java.util.Arrays.sort(rightSortedKeyIndex, new $keyIndexComparator(rightKeys)); + | + | for (int i = 0; i < numElements; i++) { + | int leftIndex = leftSortedKeyIndex[i]; + | int rightIndex = rightSortedKeyIndex[i]; + | int keyComp = $compareKeyFunc(leftKeys, leftIndex, rightKeys, rightIndex); + | if (keyComp != 0) { + | return keyComp; + | } else { + | int valueComp = $compareValueFunc(leftValues, leftIndex, rightValues, rightIndex); + | if (valueComp != 0) { + | return valueComp; + | } + | } + | } + | return 0; + |} + |""".stripMargin) + s"this.$compareMapFunc($c1, $c2)" case schema: StructType => val comparisons = GenerateOrdering.genComparisons(this, schema) val compareFunc = freshName("compareStruct") diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala index ba3ed02e06ef1..9f7639ec235d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ordering.scala @@ -65,6 +65,10 @@ class InterpretedOrdering(ordering: Seq[SortOrder]) extends BaseOrdering { a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) case a: ArrayType if order.direction == Descending => - a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) + case a: MapType if order.direction == Ascending => + a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) + case a: MapType if order.direction == Descending => + - a.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) case s: StructType if order.direction == Ascending => s.interpretedOrdering.asInstanceOf[Ordering[Any]].compare(left, right) case s: StructType if order.direction == Descending => @@ -104,6 +108,7 @@ object RowOrdering extends CodeGeneratorWithInterpretedFallback[Seq[SortOrder], case dt: AtomicType => true case struct: StructType => struct.fields.forall(f => isOrderable(f.dataType)) case array: ArrayType => isOrderable(array.elementType) + case map: MapType => isOrderable(map.keyType) && isOrderable(map.valueType) case udt: UserDefinedType[_] => isOrderable(udt.sqlType) case _ => false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index ac8766cd74367..f8efd65cca970 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.catalyst.optimizer -import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, ArrayTransform, CaseWhen, Coalesce, CreateArray, CreateMap, CreateNamedStruct, EqualTo, ExpectsInputTypes, Expression, GetStructField, If, IsNull, KnownFloatingPointNormalized, LambdaFunction, Literal, NamedLambdaVariable, TransformKeys, TransformValues, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window} @@ -95,9 +95,7 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { case FloatType | DoubleType => true case StructType(fields) => fields.exists(f => needNormalize(f.dataType)) case ArrayType(et, _) => needNormalize(et) - // Currently MapType is not comparable and analyzer should fail earlier if this case happens. - case _: MapType => - throw new IllegalStateException("grouping/join/window partition keys cannot be map type.") + case MapType(kt, vt, _) => needNormalize(kt) || needNormalize(vt) case _ => false } @@ -141,6 +139,27 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { val function = normalize(lv) KnownFloatingPointNormalized(ArrayTransform(expr, LambdaFunction(function, Seq(lv)))) + case _ if expr.dataType.isInstanceOf[MapType] => + val MapType(kt, vt, containsNull) = expr.dataType + var normalized = if (needNormalize(kt)) { + val lv1 = NamedLambdaVariable("arg1", kt, false) + val lv2 = NamedLambdaVariable("arg2", vt, containsNull) + val function = normalize(lv1) + TransformKeys(expr, LambdaFunction(function, Seq(lv1, lv2))) + } else { + expr + } + + normalized = if (needNormalize(vt)) { + val lv1 = NamedLambdaVariable("arg1", kt, false) + val lv2 = NamedLambdaVariable("arg2", vt, containsNull) + val function = normalize(lv2) + TransformValues(normalized, LambdaFunction(function, Seq(lv1, lv2))) + } else { + normalized + } + KnownFloatingPointNormalized(normalized) + case _ => throw new IllegalStateException(s"fail to normalize $expr") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapType.scala new file mode 100644 index 0000000000000..67093245a7684 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapType.scala @@ -0,0 +1,179 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.math.Ordering + +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, EqualTo, ExpectsInputTypes, Expression, NamedExpression, NamedLambdaVariable, TaggingExpression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext +import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator.{getValue, javaType} +import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode +import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LogicalPlan, Project, Window} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapBuilder, MapData, TypeUtils} +import org.apache.spark.sql.types.{AbstractDataType, DataType, MapType} + +case class KeyOrderedMap(child: Expression) extends TaggingExpression + +/** + * Spark SQL turns grouping/join/window partition keys into binary `UnsafeRow` and compare the + * binary data directly instead of using MapType's ordering. So in order to make sure two maps + * have the same key value pairs but with different key ordering generate right result, we have + * to insert an expression to sort map entries by key. + * + * Note that, this rule must be executed at the end of optimizer, because the optimizer may create + * new joins(the subquery rewrite) and new join conditions(the join reorder). + */ +object NormalizeMapType extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + case w: Window if w.partitionSpec.exists(p => needNormalize(p)) => + w.copy(partitionSpec = w.partitionSpec.map(normalize)) + + case j @ ExtractEquiJoinKeys(_, leftKeys, rightKeys, condition, _, _, _) + // The analyzer guarantees left and right joins keys are of the same data type. 
+ if leftKeys.exists(k => needNormalize(k)) => + val newLeftJoinKeys = leftKeys.map(normalize) + val newRightJoinKeys = rightKeys.map(normalize) + val newConditions = newLeftJoinKeys.zip(newRightJoinKeys).map { + case (l, r) => EqualTo(l, r) + } ++ condition + j.copy(condition = Some(newConditions.reduce(And))) + + case agg: Aggregate if agg.aggregateExpressions.exists(needNormalize) => + val replacements = agg.groupingExpressions.collect { + case e if needNormalize(e) => e + } + + agg.transformExpressionsUp { + case e => + replacements + .find(_.semanticEquals(e)) + .map(_ => normalize(e)) + .getOrElse(e) + } + + case Distinct(child) if child.output.exists(needNormalize) => + val projectList = child.output.map(normalize).asInstanceOf[Seq[NamedExpression]] + Distinct(Project(projectList, child)) + } + + private def needNormalize(expr: Expression): Boolean = expr match { + case ReorderMapKey(_) => false + case Alias(ReorderMapKey(_), _) => false + case e if e.dataType.isInstanceOf[MapType] => true + case _ => false + } + + private[sql] def normalize(expr: Expression): Expression = expr match { + case _ if !needNormalize(expr) => expr + case a: Attribute if a.dataType.isInstanceOf[MapType] => + val newAttr = a.withExprId(NamedExpression.newExprId) + Alias(ReorderMapKey(newAttr), a.name)(exprId = a.exprId, qualifier = a.qualifier) + case a: Alias => + a.withNewChildren(Seq(ReorderMapKey(a.child))) + case a: NamedLambdaVariable => + val newNLV = a.copy(exprId = NamedExpression.newExprId) + Alias(ReorderMapKey(newNLV), a.name)(exprId = a.exprId, qualifier = a.qualifier) + case e if e.dataType.isInstanceOf[MapType] => + ReorderMapKey(e) + } +} + +case class ReorderMapKey(child: Expression) extends UnaryExpression with ExpectsInputTypes { + private lazy val MapType(keyType, valueType, valueContainsNull) = dataType.asInstanceOf[MapType] + private lazy val keyOrdering: Ordering[Any] = TypeUtils.getInterpretedOrdering(keyType) + private lazy val mapBuilder = new ArrayBasedMapBuilder(keyType, valueType) + + override def inputTypes: Seq[AbstractDataType] = Seq(MapType) + + override def dataType: DataType = child.dataType + + override def nullSafeEval(input: Any): Any = { + val childMap = input.asInstanceOf[MapData] + val keys = childMap.keyArray() + val values = childMap.valueArray() + val sortedKeyIndex = (0 until childMap.numElements()).toArray.sorted(new Ordering[Int] { + override def compare(a: Int, b: Int): Int = { + keyOrdering.compare(keys.get(a, keyType), keys.get(b, keyType)) + } + }) + + var i = 0 + while (i < childMap.numElements()) { + val index = sortedKeyIndex(i) + mapBuilder.put( + keys.get(index, keyType), + if (values.isNullAt(index)) null else values.get(index, valueType)) + + i += 1 + } + + mapBuilder.build() + } + + override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { + val initIndexArrayFunc = ctx.freshName("initIndexArray") + val numElements = ctx.freshName("numElements") + val sortedKeyIndex = ctx.freshName("sortedKeyIndex") + val keyArray = ctx.freshName("keyArray") + val valueArray = ctx.freshName("valueArray") + val idx = ctx.freshName("idx") + val index = ctx.freshName("index") + val builderTerm = ctx.addReferenceObj("mapBuilder", mapBuilder) + ctx.addNewFunction(initIndexArrayFunc, + s""" + |private Integer[] $initIndexArrayFunc(int n) { + | Integer[] arr = new Integer[n]; + | for (int i = 0; i < n; i++) { + | arr[i] = i; + | } + | return arr; + |}""".stripMargin) + + val codeToNormalize = (f: String) => { + s""" + |int $numElements = 
$f.numElements(); + |Integer[] $sortedKeyIndex = $initIndexArrayFunc($numElements); + |final ArrayData $keyArray = $f.keyArray(); + |final ArrayData $valueArray = $f.valueArray(); + |java.util.Arrays.sort($sortedKeyIndex, new java.util.Comparator() { + | @Override + | public int compare(Object a, Object b) { + | int indexA = ((Integer)a).intValue(); + | int indexB = ((Integer)b).intValue(); + | ${javaType(keyType)} keyA = ${getValue(keyArray, keyType, "indexA")}; + | ${javaType(keyType)} keyB = ${getValue(keyArray, keyType, "indexB")}; + | return ${ctx.genComp(keyType, "keyA", "keyB")}; + | } + |}); + | + |for (int $idx = 0; $idx < $numElements; $idx++) { + | Integer $index = $sortedKeyIndex[$idx]; + | $builderTerm.put( + | ${getValue(keyArray, keyType, index)}, + | $valueArray.isNullAt($index) ? null : ${getValue(valueArray, valueType, index)}); + |} + | + |${ev.value} = $builderTerm.build(); + |""".stripMargin + } + + nullSafeCodeGen(ctx, ev, codeToNormalize) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3e3550d5da89b..cfdf4bcf27729 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -230,10 +230,12 @@ abstract class Optimizer(catalogManager: CatalogManager) ColumnPruning, CollapseProject, RemoveNoopOperators) :+ - // This batch must be executed after the `RewriteSubquery` batch, which creates joins. + // Following batches must be executed after the `RewriteSubquery` batch, which creates joins. Batch("NormalizeFloatingNumbers", Once, NormalizeFloatingNumbers) :+ + Batch("NormalizeMapType", Once, NormalizeMapType) :+ Batch("ReplaceUpdateFieldsExpression", Once, ReplaceUpdateFieldsExpression) + // remove any batches with no rules. this may happen when subclasses do not add optional rules. batches.filter(_.rules.nonEmpty) } @@ -266,7 +268,8 @@ abstract class Optimizer(catalogManager: CatalogManager) RewriteCorrelatedScalarSubquery.ruleName :: RewritePredicateSubquery.ruleName :: NormalizeFloatingNumbers.ruleName :: - ReplaceUpdateFieldsExpression.ruleName :: Nil + ReplaceUpdateFieldsExpression.ruleName :: + NormalizeMapType.ruleName :: Nil /** * Optimize all the subqueries inside expression. 
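[Reviewer illustration, not part of the patch.] The idea behind the normalization above is easiest to see outside of Spark: once a map's entries are rebuilt in ascending key order, two maps holding the same key/value pairs always end up with the same physical layout, so the byte-wise UnsafeRow comparison used for grouping/join/window keys gives the right answer. Below is a minimal standalone sketch in plain Scala; the object and method names are made up for illustration and do not exist in this patch.

object SortMapKeysSketch {
  // Rebuild a map's entries in ascending key order. This mirrors, at a high level,
  // what the key-sorting expression in this patch does to MapData before the keys
  // are turned into binary UnsafeRow and compared byte-wise.
  def sortEntries[K, V](entries: Seq[(K, V)])(implicit ord: Ordering[K]): Seq[(K, V)] =
    entries.sortBy(_._1)

  def main(args: Array[String]): Unit = {
    val left  = Seq("b" -> 2, "a" -> 1, "c" -> 3)
    val right = Seq("a" -> 1, "c" -> 3, "b" -> 2)
    // The raw entry orders differ, so a naive positional or binary comparison would not match...
    assert(left != right)
    // ...but after sorting both sides by key, the layouts are identical.
    assert(sortEntries(left) == sortEntries(right))
  }
}

In the patch itself the same reordering is performed through ArrayBasedMapBuilder over sorted key indices, both in the interpreted path (nullSafeEval) and in the generated Java code (doGenCode).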
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala index 6212e8f48c04b..eda07f248d5ea 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/TypeUtils.scala @@ -73,6 +73,7 @@ object TypeUtils { t match { case i: AtomicType => i.ordering.asInstanceOf[Ordering[Any]] case a: ArrayType => a.interpretedOrdering.asInstanceOf[Ordering[Any]] + case m: MapType => m.interpretedOrdering.asInstanceOf[Ordering[Any]] case s: StructType => s.interpretedOrdering.asInstanceOf[Ordering[Any]] case udt: UserDefinedType[_] => getInterpretedOrdering(udt.sqlType) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala index 2e5c7f731dcc7..398d401d41752 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/MapType.scala @@ -17,10 +17,13 @@ package org.apache.spark.sql.types +import scala.math.Ordering + import org.json4s.JsonAST.JValue import org.json4s.JsonDSL._ import org.apache.spark.annotation.Stable +import org.apache.spark.sql.catalyst.util.{ArrayData, MapData, TypeUtils} import org.apache.spark.sql.catalyst.util.StringUtils.StringConcat /** @@ -79,6 +82,69 @@ case class MapType( override private[spark] def existsRecursively(f: (DataType) => Boolean): Boolean = { f(this) || keyType.existsRecursively(f) || valueType.existsRecursively(f) } + + @transient + private[sql] lazy val keyOrdering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(keyType) + + @transient + private[sql] lazy val valueOrdering: Ordering[Any] = + TypeUtils.getInterpretedOrdering(valueType) + + @transient + private[sql] lazy val interpretedOrdering: Ordering[MapData] = new Ordering[MapData] { + def compare(left: MapData, right: MapData): Int = { + if (left.numElements() != right.numElements()) { + return left.numElements() - right.numElements() + } + + val numElements = left.numElements() + val leftKeys = left.keyArray() + val rightKeys = right.keyArray() + val leftValues = left.valueArray() + val rightValues = right.valueArray() + + val keyIndexOrdering = (keys: ArrayData) => new Ordering[Int] { + override def compare(a: Int, b: Int): Int = { + keyOrdering.compare(keys.get(a, keyType), keys.get(b, keyType)) + } + } + val leftSortedKeyIndex = (0 until numElements).toArray.sorted(keyIndexOrdering(leftKeys)) + val rightSortedKeyIndex = (0 until numElements).toArray.sorted(keyIndexOrdering(rightKeys)) + var i = 0 + while (i < numElements) { + val leftIndex = leftSortedKeyIndex(i) + val rightIndex = rightSortedKeyIndex(i) + + val leftKey = leftKeys.get(leftIndex, keyType) + val rightKey = rightKeys.get(rightIndex, keyType) + val keyComp = keyOrdering.compare(leftKey, rightKey) + if (keyComp != 0) { + return keyComp + } else { + val leftValueIsNull = leftValues.isNullAt(leftIndex) + val rightValueIsNull = rightValues.isNullAt(rightIndex) + if (leftValueIsNull && rightValueIsNull) { + // do nothing + } else if (leftValueIsNull) { + return -1 + } else if (rightValueIsNull) { + return 1 + } else { + val leftValue = leftValues.get(leftIndex, valueType) + val rightValue = rightValues.get(rightIndex, valueType) + val valueComp = valueOrdering.compare(leftValue, rightValue) + if (valueComp != 0) { + return valueComp + } + } + } + i += 1 + } 
+ + 0 + } + } } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index ec5a9cc9afad5..dec9e8a0da9cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -297,11 +297,6 @@ class AnalysisErrorSuite extends AnalysisTest { testRelation2.groupBy($"a")(sum(UnresolvedStar(None))), "Invalid usage of '*'" :: "in expression 'sum'" :: Nil) - errorTest( - "sorting by unsupported column types", - mapRelation.orderBy($"map".asc), - "sort" :: "type" :: "map" :: Nil) - errorTest( "sorting by attributes are not from grouping expressions", testRelation2.groupBy($"a", $"c")($"a", $"c", count($"a").as("a3")).orderBy($"b".asc), @@ -623,24 +618,6 @@ class AnalysisErrorSuite extends AnalysisTest { "another aggregate function." :: Nil) } - test("Join can work on binary types but can't work on map types") { - val left = LocalRelation(Symbol("a").binary, Symbol("b").map(StringType, StringType)) - val right = LocalRelation(Symbol("c").binary, Symbol("d").map(StringType, StringType)) - - val plan1 = left.join( - right, - joinType = Cross, - condition = Some(Symbol("a") === Symbol("c"))) - - assertAnalysisSuccess(plan1) - - val plan2 = left.join( - right, - joinType = Cross, - condition = Some(Symbol("b") === Symbol("d"))) - assertAnalysisError(plan2, "EqualTo does not support ordering on type map" :: Nil) - } - test("PredicateSubQuery is used outside of a filter") { val a = AttributeReference("a", IntegerType)() val b = AttributeReference("b", IntegerType)() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index 44f333342d1c8..ce49622a3234f 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -113,19 +113,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertErrorForDifferingTypes(GreaterThan(Symbol("intField"), Symbol("booleanField"))) assertErrorForDifferingTypes(GreaterThanOrEqual(Symbol("intField"), Symbol("booleanField"))) - assertError(EqualTo(Symbol("mapField"), Symbol("mapField")), - "EqualTo does not support ordering on type map") - assertError(EqualNullSafe(Symbol("mapField"), Symbol("mapField")), - "EqualNullSafe does not support ordering on type map") - assertError(LessThan(Symbol("mapField"), Symbol("mapField")), - "LessThan does not support ordering on type map") - assertError(LessThanOrEqual(Symbol("mapField"), Symbol("mapField")), - "LessThanOrEqual does not support ordering on type map") - assertError(GreaterThan(Symbol("mapField"), Symbol("mapField")), - "GreaterThan does not support ordering on type map") - assertError(GreaterThanOrEqual(Symbol("mapField"), Symbol("mapField")), - "GreaterThanOrEqual does not support ordering on type map") - assertError(If(Symbol("intField"), Symbol("stringField"), Symbol("stringField")), "type of predicate expression in If should be boolean") assertErrorForDifferingTypes( @@ -227,8 +214,6 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertError(operator(Seq(Symbol("booleanField"))), "requires at least 
two arguments") assertError(operator(Seq(Symbol("intField"), Symbol("stringField"))), "should all have the same type") - assertError(operator(Seq(Symbol("mapField"), Symbol("mapField"))), - "does not support ordering") } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala index 6f75623dc59ae..1b0850512a36c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/PredicateSuite.scala @@ -233,14 +233,6 @@ class PredicateSuite extends SparkFunSuite with ExpressionEvalHelper { StructField("a", colOneType) :: StructField("b", colTwoType) :: Nil) testWithRandomDataGeneration(structType, nullable) } - - // In doesn't support map type and will fail the analyzer. - val map = Literal.create(create_map(1 -> 1), MapType(IntegerType, IntegerType)) - In(map, Seq(map)).checkInputDataTypes() match { - case TypeCheckResult.TypeCheckFailure(msg) => - assert(msg.contains("function in does not support ordering on type map")) - case _ => fail("In should not work on map type") - } } test("switch statements in InSet for bytes, shorts, ints, dates") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 3e137d49e64c3..4d04a08ac9e04 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -835,6 +835,16 @@ class DataFrameAggregateSuite extends QueryTest df.groupBy("arr", "stru", "arrOfStru").count(), Row(Seq(0.0f, 0.0f), Row(0.0d, Double.NaN), Seq(Row(0.0d, Double.NaN)), 2) ) + + // test with map type grouping columns + val df2 = Seq( + (Map("a" -> 0.0f, "b" -> -0.0f), Map(-0.0d -> Double.NaN)), + (Map("a" -> -0.0f, "b" -> 0.0f), Map(0.0d -> 0.0 / 0.0)) + ).toDF("m1", "m2") + checkAnswer( + df2.groupBy("m1", "m2").count(), + Row(Map("a" -> 0.0f, "b" -> 0.0f), Map(0.0d -> Double.NaN), 2) + ) } test("SPARK-27581: DataFrame count_distinct(\"*\") shouldn't fail with AnalysisException") { @@ -895,11 +905,10 @@ class DataFrameAggregateSuite extends QueryTest .toDF("x", "y") .select($"x", map($"x", $"y").as("y")) .createOrReplaceTempView("tempView") - val error = intercept[AnalysisException] { - sql("SELECT max_by(x, y) FROM tempView").show - } - assert( - error.message.contains("function max_by does not support ordering on type map")) + checkAnswer( + sql("SELECT max_by(x, y) FROM tempView"), + Row(2) :: Nil + ) } } @@ -951,11 +960,10 @@ class DataFrameAggregateSuite extends QueryTest .toDF("x", "y") .select($"x", map($"x", $"y").as("y")) .createOrReplaceTempView("tempView") - val error = intercept[AnalysisException] { - sql("SELECT min_by(x, y) FROM tempView").show - } - assert( - error.message.contains("function min_by does not support ordering on type map")) + checkAnswer( + sql("SELECT min_by(x, y) FROM tempView"), + Row(0) :: Nil + ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index a1c6133a24c82..cfb9bebbb93e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -3038,11 +3038,6 
@@ class DataFrameFunctionsSuite extends QueryTest with SharedSparkSession { df.select(map_zip_with(col("mis"), col("i"), (x, y, z) => concat(x, y, z))) } assert(ex4a.getMessage.contains("type mismatch: argument 2 requires map type")) - - val ex5 = intercept[AnalysisException] { - df.selectExpr("map_zip_with(mmi, mmi, (x, y, z) -> x)") - } - assert(ex5.getMessage.contains("function map_zip_with does not support ordering on type map")) } test("transform keys function - primitive data types") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index 797673ae15ba8..70dee3166fc06 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -347,23 +347,24 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { dates.intersect(widenTypedRows).collect() } - test("SPARK-19893: cannot run set operations with map type") { - val df = spark.range(1).select(map(lit("key"), $"id").as("m")) - val e = intercept[AnalysisException](df.intersect(df)) - assert(e.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - val e2 = intercept[AnalysisException](df.except(df)) - assert(e2.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - val e3 = intercept[AnalysisException](df.distinct()) - assert(e3.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - withTempView("v") { - df.createOrReplaceTempView("v") - val e4 = intercept[AnalysisException](sql("SELECT DISTINCT m FROM v")) - assert(e4.message.contains( - "Cannot have map type columns in DataFrame which calls set operations")) - } + test("SPARK-34819: set operations with map type") { + val df = spark.range(0, 2).select(map(lit("key"), $"id").as("m")) + val df2 = spark.range(1, 2).select(map(lit("key"), $"id").as("m")) + checkAnswer( + df.intersect(df2), + Row(Map("key" -> "1")) :: Nil + ) + + checkAnswer( + df.except(df2), + Row(Map("key" -> "0")) :: Nil + ) + + checkAnswer( + df.distinct(), + Row(Map("key" -> "0")) :: + Row(Map("key" -> "1")) :: Nil + ) } test("union all") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index 666bf739ca9c9..af1b0691fc6f3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -1043,6 +1043,18 @@ class DataFrameWindowFunctionsSuite extends QueryTest Seq( Row(Seq(-0.0f, 0.0f), Row(-0.0d, Double.NaN), Seq(Row(-0.0d, Double.NaN)), 2), Row(Seq(0.0f, -0.0f), Row(0.0d, Double.NaN), Seq(Row(0.0d, 0.0/0.0)), 2))) + + // test with df with map type columns. 
+ val df3 = Seq( + (Map("a" -> 0.0f, "b" -> -0.0f), Map(-0.0d -> Double.NaN)), + (Map("a" -> -0.0f, "b" -> 0.0f), Map(0.0d -> 0.0/0.0)) + ).toDF("m1", "m2") + val windowSpec4 = Window.partitionBy("m1", "m2") + checkAnswer( + df3.select($"m1", $"m2", count(lit(1)).over(windowSpec4)), + Seq( + Row(Map("a" -> 0.0f, "b" -> -0.0f), Map(-0.0d -> Double.NaN), 2), + Row(Map("a" -> -0.0f, "b" -> 0.0f), Map(0.0d -> 0.0/0.0), 2))) } test("SPARK-34227: WindowFunctionFrame should clear its states during preparation") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 3e1ad8114876a..1cc31020b09f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -979,7 +979,7 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan } test("NaN and -0.0 in join keys") { - withTempView("v1", "v2", "v3", "v4") { + withTempView("v1", "v2", "v3", "v4", "v5", "v6") { Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d").createTempView("v1") Seq(Float.NaN -> Double.NaN, 0.0f -> 0.0, -0.0f -> -0.0).toDF("f", "d").createTempView("v2") @@ -1035,6 +1035,24 @@ class JoinSuite extends QueryTest with SharedSparkSession with AdaptiveSparkPlan Seq(0.0f, -0.0f), Row(0.0d, 0.0/0.0), Seq(Row(0.0d, 0.0/0.0))))) + + // SPARK-34819 test with tables with map type columns + Seq((Map("a" -> 0.0f, "b" -> -0.0f), Map(0.0d -> Double.NaN))) + .toDF("m1", "m2").createTempView("v5") + Seq((Map("a" -> -0.0f, "b" -> 0.0f), Map(-0.0d -> 0.0/0.0))) + .toDF("m1", "m2").createTempView("v6") + checkAnswer( + sql( + """ + |SELECT v5.m1, v5.m2, v6.m1, v6.m2 + |FROM v5 JOIN v6 + |ON v5.m1 = v6.m1 AND v5.m2 = v6.m2 + """.stripMargin), + Seq(Row( + Map("a" -> 0.0f, "b" -> -0.0f), + Map(0.0d -> Double.NaN), + Map("a" -> -0.0f, "b" -> 0.0f), + Map(-0.0d -> 0.0/0.0)))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 7e7853e1799d4..5f790dcf3170d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -22,11 +22,13 @@ import java.net.{MalformedURLException, URL} import java.sql.{Date, Timestamp} import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.mutable + import org.apache.commons.io.FileUtils import org.apache.spark.{AccumulatorSuite, SparkException} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} -import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.catalyst.expressions.{CodegenObjectFactoryMode, GenericRow} import org.apache.spark.sql.catalyst.expressions.aggregate.{Complete, Partial} import org.apache.spark.sql.catalyst.optimizer.{ConvertToLocalRelation, NestedColumnAliasingSuite} import org.apache.spark.sql.catalyst.plans.logical.{LocalLimit, Project, RepartitionByExpression, Sort} @@ -511,7 +513,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark Row("1")) } - def sortTest(): Unit = { + def sorttest(): Unit = { checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) @@ -554,7 +556,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } test("external sorting") { - sortTest() + sorttest() } test("CTE 
feature") { @@ -4117,7 +4119,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } } - + test("SPARK-33482: Fix FileScan canonicalization") { withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { withTempPath { path => @@ -4140,6 +4142,136 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } } + + test("SPARK-34819: MapType supports orderable semantics") { + Seq(CodegenObjectFactoryMode.CODEGEN_ONLY.toString, + CodegenObjectFactoryMode.NO_CODEGEN.toString).foreach { + case codegenMode => + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode, + SQLConf.USE_OBJECT_HASH_AGG.key -> "false") { + withTable("t", "t2") { + val df = Seq( + Map("a" -> 1, "b" -> 2, "c" -> 3), + Map("c" -> 3, "b" -> 2, "a" -> 1), + Map("d" -> 4), + Map("a" -> 1, "e" -> 2), + Map("d" -> 4), + Map("d" -> 5)).toDF("m") + val df2 = Seq( + Map("b" -> 2, "a" -> 1, "c" -> 3), + ).toDF("m2") + df.createOrReplaceTempView("t") + df2.createOrReplaceTempView("t2") + + checkAnswer( + sql("select m, count(1) from t group by m"), + Row(Map("d" -> 4), 2) :: + Row(Map("d" -> 5), 1) :: + Row(Map("a" -> 1, "e" -> 2), 1) :: + Row(Map("a" -> 1, "b" -> 2, "c" -> 3), 2) :: Nil + ) + + checkAnswer( + sql("select distinct m from t"), + Row(Map("d" -> 4)) :: + Row(Map("d" -> 5)) :: + Row(Map("a" -> 1, "e" -> 2)) :: + Row(Map("a" -> 1, "b" -> 2, "c" -> 3)) :: Nil + ) + + checkAnswer( + sql("select m from t order by m"), + Row(Map("d" -> 4)) :: + Row(Map("d" -> 4)) :: + Row(Map("d" -> 5)) :: + Row(Map("a" -> 1, "e" -> 2)) :: + Row(Map("a" -> 1, "b" -> 2, "c" -> 3)) :: + Row(Map("c" -> 3, "b" -> 2, "a" -> 1)) :: Nil + ) + + checkAnswer( + sql("select m, count(1) over (partition by m) from t"), + Row(Map("d" -> 4), 2) :: + Row(Map("d" -> 4), 2) :: + Row(Map("d" -> 5), 1) :: + Row(Map("a" -> 1, "e" -> 2), 1) :: + Row(Map("a" -> 1, "b" -> 2, "c" -> 3), 2) :: + Row(Map("c" -> 3, "b" -> 2, "a" -> 1), 2) :: Nil + ) + + checkAnswer( + sql( + """select m2, count(1), percentile(id, 0.5) from ( + | select + | case when size(m) == 3 then m else map('b', 2, 'a', 1, 'c', 3) + | end as m2, + | 1 as id + | from t + |) + |group by m2 + |""".stripMargin), + Row(Map("a" -> 1, "b" -> 2, "c" -> 3), 6, 1.0) + ) + + checkAnswer( + sql("select m, m2 from t join t2 on t.m = t2.m2"), + Row(Map("a" -> 1, "b" -> 2, "c" -> 3), Map("b" -> 2, "a" -> 1, "c" -> 3)) :: + Row(Map("c" -> 3, "b" -> 2, "a" -> 1), Map("b" -> 2, "a" -> 1, "c" -> 3)) :: Nil + ) + + checkAnswer( + sql("select distinct m, m2 from t join t2 on t.m = t2.m2"), + Row(Map("a" -> 1, "b" -> 2, "c" -> 3), Map("a" -> 1, "b" -> 2, "c" -> 3)) :: Nil + ) + + checkAnswer( + sql("select m from t where m = map('b', 2, 'a', 1, 'c', 3)"), + Row(Map('a' -> 1, 'b' -> 2, 'c' -> 3)) :: + Row(Map('c' -> 3, 'b' -> 2, 'a' -> 3)) :: Nil + ) + } + } + } + } + + test("SPARK-34819: MapType has nesting complex type supports orderable semantics") { + Seq(CodegenObjectFactoryMode.CODEGEN_ONLY.toString, + CodegenObjectFactoryMode.NO_CODEGEN.toString).foreach { + case codegenMode => + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { + withTable("t") { + val df = Seq( + Map("a" -> Map("hello" -> Array("i", "j"))), + Map("a" -> Map("world" -> Array("o", "p"))), + Map("a" -> Map("hello" -> Array("m", "n"))), + Map("a" -> Map("hello" -> Array("i", "j")))).toDF("m") + df.createOrReplaceTempView("t") + checkAnswer( + sql("select m, count(1) from t group by m"), + Row(Map("a" -> Map("hello" -> mutable.WrappedArray.make(Array("i", "j")))), 2) :: + 
Row(Map("a" -> Map("hello" -> mutable.WrappedArray.make(Array("m", "n")))), 1) :: + Row(Map("a" -> Map("world" -> mutable.WrappedArray.make(Array("o", "p")))), 1) :: + Nil + ) + + checkAnswer( + sql("select distinct m from t"), + Row(Map("a" -> Map("hello" -> mutable.WrappedArray.make(Array("i", "j"))))) :: + Row(Map("a" -> Map("hello" -> mutable.WrappedArray.make(Array("m", "n"))))) :: + Row(Map("a" -> Map("world" -> mutable.WrappedArray.make(Array("o", "p"))))) :: Nil + ) + + checkAnswer( + sql("select m from t order by m"), + Row(Map("a" -> Map("hello" -> mutable.WrappedArray.make(Array("i", "j"))))) :: + Row(Map("a" -> Map("hello" -> mutable.WrappedArray.make(Array("i", "j"))))) :: + Row(Map("a" -> Map("hello" -> mutable.WrappedArray.make(Array("m", "n"))))) :: + Row(Map("a" -> Map("world" -> mutable.WrappedArray.make(Array("o", "p"))))) :: Nil + ) + } + } + } + } } case class Foo(bar: Option[String]) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala index 0a5feda1bd533..b5e8b2e69e48c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala @@ -91,11 +91,6 @@ abstract class BucketedWriteSuite extends QueryTest with SQLTestUtils { assert(e.getMessage == "sortBy must be used together with bucketBy") } - test("sorting by non-orderable column") { - val df = Seq("a" -> Map(1 -> 1), "b" -> Map(2 -> 2)).toDF("i", "j") - intercept[AnalysisException](df.write.bucketBy(2, "i").sortBy("j").saveAsTable("tt")) - } - test("write bucketed data using save()") { val df = Seq(1 -> "a", 2 -> "b").toDF("i", "j") From ac712ceb9ee5eee4adc830989362b65106f24749 Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Thu, 25 Mar 2021 13:02:22 +0800 Subject: [PATCH 2/5] update --- .../catalyst/optimizer/NormalizeMapType.scala | 33 ++----------------- .../spark/sql/execution/SparkStrategies.scala | 23 ++++++++++++- .../resources/sql-tests/results/pivot.sql.out | 12 +++---- .../sql-tests/results/udf/udf-pivot.sql.out | 12 +++---- .../org/apache/spark/sql/SQLQuerySuite.scala | 4 +-- 5 files changed, 39 insertions(+), 45 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapType.scala index 67093245a7684..4384a98c4de59 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapType.scala @@ -19,19 +19,18 @@ package org.apache.spark.sql.catalyst.optimizer import scala.math.Ordering -import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, EqualTo, ExpectsInputTypes, Expression, NamedExpression, NamedLambdaVariable, TaggingExpression, UnaryExpression} +import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, ExpectsInputTypes, Expression, UnaryExpression} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator.{getValue, javaType} import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Distinct, LogicalPlan, Project, Window} +import 
org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Window} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.util.{ArrayBasedMapBuilder, MapData, TypeUtils} import org.apache.spark.sql.types.{AbstractDataType, DataType, MapType} -case class KeyOrderedMap(child: Expression) extends TaggingExpression - /** + * For two * Spark SQL turns grouping/join/window partition keys into binary `UnsafeRow` and compare the * binary data directly instead of using MapType's ordering. So in order to make sure two maps * have the same key value pairs but with different key ordering generate right result, we have @@ -54,42 +53,16 @@ object NormalizeMapType extends Rule[LogicalPlan] { case (l, r) => EqualTo(l, r) } ++ condition j.copy(condition = Some(newConditions.reduce(And))) - - case agg: Aggregate if agg.aggregateExpressions.exists(needNormalize) => - val replacements = agg.groupingExpressions.collect { - case e if needNormalize(e) => e - } - - agg.transformExpressionsUp { - case e => - replacements - .find(_.semanticEquals(e)) - .map(_ => normalize(e)) - .getOrElse(e) - } - - case Distinct(child) if child.output.exists(needNormalize) => - val projectList = child.output.map(normalize).asInstanceOf[Seq[NamedExpression]] - Distinct(Project(projectList, child)) } private def needNormalize(expr: Expression): Boolean = expr match { case ReorderMapKey(_) => false - case Alias(ReorderMapKey(_), _) => false case e if e.dataType.isInstanceOf[MapType] => true case _ => false } private[sql] def normalize(expr: Expression): Expression = expr match { case _ if !needNormalize(expr) => expr - case a: Attribute if a.dataType.isInstanceOf[MapType] => - val newAttr = a.withExprId(NamedExpression.newExprId) - Alias(ReorderMapKey(newAttr), a.name)(exprId = a.exprId, qualifier = a.qualifier) - case a: Alias => - a.withNewChildren(Seq(ReorderMapKey(a.child))) - case a: NamedLambdaVariable => - val newNLV = a.copy(exprId = NamedExpression.newExprId) - Alias(ReorderMapKey(newNLV), a.name)(exprId = a.exprId, qualifier = a.qualifier) case e if e.dataType.isInstanceOf[MapType] => ReorderMapKey(e) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 0e720508c02a0..dc49eed4afda0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression -import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, JoinSelectionHelper, NormalizeFloatingNumbers} +import org.apache.spark.sql.catalyst.optimizer.{BuildLeft, BuildRight, JoinSelectionHelper, NormalizeFloatingNumbers, NormalizeMapType} import org.apache.spark.sql.catalyst.planning._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ @@ -340,6 +340,11 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case n: NamedExpression => n case other => Alias(other, e.name)(exprId = e.exprId) } + }.map { e => + NormalizeMapType.normalize(e) match { + case n: NamedExpression => n + case other => Alias(other, e.name)(exprId = e.exprId) + } } AggUtils.planStreamingAggregation( @@ -449,6 +454,11 @@ abstract 
class SparkStrategies extends QueryPlanner[SparkPlan] { // Keep the name of the original expression. case other => Alias(other, e.name)(exprId = e.exprId) } + }.map { e => + NormalizeMapType.normalize(e) match { + case n: NamedExpression => n + case other => Alias(other, e.name)(exprId = e.exprId) + } } val aggregateOperator = @@ -480,6 +490,17 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { } Alias(other, name)() } + }.map { e => + NormalizeMapType.normalize(e) match { + case n: NamedExpression => n + case other => + // Keep the name of the original expression. + val name = e match { + case ne: NamedExpression => ne.name + case _ => e.toString + } + Alias(other, name)() + } } AggUtils.planAggregateWithOneDistinct( diff --git a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out index 69679f8be5fe4..52ca01ff04085 100644 --- a/sql/core/src/test/resources/sql-tests/results/pivot.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/pivot.sql.out @@ -455,10 +455,10 @@ PIVOT ( FOR m IN (map('1', 1), map('2', 2)) ) -- !query schema -struct<> +struct 1}:bigint,{2 -> 2}:bigint> -- !query output -org.apache.spark.sql.AnalysisException -Invalid pivot column 'm#x'. Pivot columns must be comparable. +2012 35000 NULL +2013 NULL 78000 -- !query @@ -472,10 +472,10 @@ PIVOT ( FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2))) ) -- !query schema -struct<> +struct 1}}:bigint,{Java, {2 -> 2}}:bigint> -- !query output -org.apache.spark.sql.AnalysisException -Invalid pivot column 'named_struct(course, course#x, m, m#x)'. Pivot columns must be comparable. +2012 15000 NULL +2013 NULL 30000 -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/udf/udf-pivot.sql.out b/sql/core/src/test/resources/sql-tests/results/udf/udf-pivot.sql.out index dc5cc29762657..eb2af4ec5d4eb 100644 --- a/sql/core/src/test/resources/sql-tests/results/udf/udf-pivot.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/udf/udf-pivot.sql.out @@ -421,10 +421,10 @@ PIVOT ( FOR m IN (map('1', 1), map('2', 2)) ) -- !query schema -struct<> +struct 1}:bigint,{2 -> 2}:bigint> -- !query output -org.apache.spark.sql.AnalysisException -Invalid pivot column 'm#x'. Pivot columns must be comparable. +2012 35000 NULL +2013 NULL 78000 -- !query @@ -438,10 +438,10 @@ PIVOT ( FOR (course, m) IN (('dotNET', map('1', 1)), ('Java', map('2', 2))) ) -- !query schema -struct<> +struct 1}}:bigint,{Java, {2 -> 2}}:bigint> -- !query output -org.apache.spark.sql.AnalysisException -Invalid pivot column 'named_struct(course, course#x, m, m#x)'. Pivot columns must be comparable. 
+2012 15000 NULL +2013 NULL 30000 -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5f790dcf3170d..f13da68924eba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -4226,8 +4226,8 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark checkAnswer( sql("select m from t where m = map('b', 2, 'a', 1, 'c', 3)"), - Row(Map('a' -> 1, 'b' -> 2, 'c' -> 3)) :: - Row(Map('c' -> 3, 'b' -> 2, 'a' -> 3)) :: Nil + Row(Map("a" -> 1, "b" -> 2, "c" -> 3)) :: + Row(Map("c" -> 3, "b" -> 2, "a" -> 1)) :: Nil ) } } From e9366fe797f3e624ea8b41fcaa2e76a7fc5d53f7 Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Thu, 25 Mar 2021 21:45:32 +0800 Subject: [PATCH 3/5] update --- .../expressions/codegen/CodeGenerator.scala | 4 +- .../catalyst/optimizer/NormalizeMapType.scala | 17 ++-- .../optimizer/NormalizeMapTypeSuite.scala | 82 +++++++++++++++++++ .../spark/sql/execution/SparkStrategies.scala | 13 +-- 4 files changed, 101 insertions(+), 15 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapTypeSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index bf872b29e195d..9506a59865267 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -731,8 +731,8 @@ class CodegenContext extends Logging { | | @Override | public int compare(Object a, Object b) { - | Integer indexA = (Integer)a; - | Integer indexB = (Integer)b; + | int indexA = ((Integer)a).intValue(); + | int indexB = ((Integer)b).intValue(); | ${javaType(keyType)} keyA = ${getValue("array", keyType, "indexA")}; | ${javaType(keyType)} keyB = ${getValue("array", keyType, "indexB")}; | return ${genComp(keyType, "keyA", "keyB")}; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapType.scala index 4384a98c4de59..937c10972e6a1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapType.scala @@ -30,10 +30,13 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapBuilder, MapData, TypeUt import org.apache.spark.sql.types.{AbstractDataType, DataType, MapType} /** - * For two - * Spark SQL turns grouping/join/window partition keys into binary `UnsafeRow` and compare the - * binary data directly instead of using MapType's ordering. So in order to make sure two maps - * have the same key value pairs but with different key ordering generate right result, we have + * When comparing two maps, we have to make sure two maps have the same key value pairs but + * with different key ordering are equal. + * For example, Map('a' -> 1, 'b' -> 2) equals to Map('b' -> 2, 'a' -> 1). 
+ * + * We have to specially handle this in grouping/join/window because Spark SQL turns + * grouping/join/window partition keys into binary `UnsafeRow` and compare the + * binary data directly instead of using MapType's ordering. So in these cases, we have * to insert an expression to sort map entries by key. * * Note that, this rule must be executed at the end of optimizer, because the optimizer may create @@ -56,7 +59,7 @@ object NormalizeMapType extends Rule[LogicalPlan] { } private def needNormalize(expr: Expression): Boolean = expr match { - case ReorderMapKey(_) => false + case SortMapKey(_) => false case e if e.dataType.isInstanceOf[MapType] => true case _ => false } @@ -64,11 +67,11 @@ object NormalizeMapType extends Rule[LogicalPlan] { private[sql] def normalize(expr: Expression): Expression = expr match { case _ if !needNormalize(expr) => expr case e if e.dataType.isInstanceOf[MapType] => - ReorderMapKey(e) + SortMapKey(e) } } -case class ReorderMapKey(child: Expression) extends UnaryExpression with ExpectsInputTypes { +case class SortMapKey(child: Expression) extends UnaryExpression with ExpectsInputTypes { private lazy val MapType(keyType, valueType, valueContainsNull) = dataType.asInstanceOf[MapType] private lazy val keyOrdering: Ordering[Any] = TypeUtils.getInterpretedOrdering(keyType) private lazy val mapBuilder = new ArrayBasedMapBuilder(keyType, valueType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapTypeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapTypeSuite.scala new file mode 100644 index 0000000000000..427e25ab8877e --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeMapTypeSuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.types.{MapType, StringType} + +class NormalizeMapTypeSuite extends PlanTest { + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = Batch("NormalizeMapType", Once, NormalizeMapType) :: Nil + } + + val testRelation1 = LocalRelation('a.int, 'm.map(MapType(StringType, StringType, false))) + val a1 = testRelation1.output(0) + val m1 = testRelation1.output(1) + + val testRelation2 = LocalRelation('a.int, 'm.map(MapType(StringType, StringType, false))) + val a2 = testRelation2.output(0) + val m2 = testRelation2.output(1) + + test("normalize map types in window function expressions") { + val query = testRelation1.window(Seq(sum(a1).as("sum")), Seq(m1), Seq(m1.asc)) + val optimized = Optimize.execute(query) + val correctAnswer = testRelation1.window(Seq(sum(a1).as("sum")), + Seq(SortMapKey(m1)), Seq(m1.asc)) + + comparePlans(optimized, correctAnswer) + } + + test("normalize map types in window function expressions - idempotence") { + val query = testRelation1.window(Seq(sum(a1).as("sum")), Seq(m1), Seq(m1.asc)) + val optimized = Optimize.execute(query) + val doubleOptimized = Optimize.execute(optimized) + val correctAnswer = testRelation1.window(Seq(sum(a1).as("sum")), + Seq(SortMapKey(m1)), Seq(m1.asc)) + + comparePlans(doubleOptimized, correctAnswer) + } + + test("normalize map types in join keys") { + val query = testRelation1.join(testRelation2, condition = Some(m1 === m2)) + + val optimized = Optimize.execute(query) + val joinCond = Some(SortMapKey(m1) === SortMapKey(m2)) + val correctAnswer = testRelation1.join(testRelation2, condition = joinCond) + + comparePlans(optimized, correctAnswer) + } + + test("normalize map types in join keys - idempotence") { + val query = testRelation1.join(testRelation2, condition = Some(m1 === m2)) + + val optimized = Optimize.execute(query) + val doubleOptimized = Optimize.execute(optimized) + val joinCond = Some(SortMapKey(m1) === SortMapKey(m2)) + val correctAnswer = testRelation1.join(testRelation2, condition = joinCond) + + comparePlans(doubleOptimized, correctAnswer) + } +} + + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index dc49eed4afda0..7c7a32faaa09d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -333,8 +333,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val stateVersion = conf.getConf(SQLConf.STREAMING_AGGREGATION_STATE_FORMAT_VERSION) - // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because - // `groupingExpressions` is not extracted during logical phase. + // Ideally this should be done in `NormalizeFloatingNumbers` and `NormalizeMapType`, + // but we do it here because `groupingExpressions` is not extracted during logical phase. 
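The rule documentation above and the planner comment here describe the same idea: make map comparison insensitive to key order by sorting entries before the binary comparison. A minimal plain-Scala sketch of that idea follows (illustration only; it does not use Spark's internal MapData, ArrayBasedMapBuilder or UnsafeRow machinery, and the object and method names are made up):

  // Illustration only: the key-sorting idea behind SortMapKey, in plain Scala.
  // Sorting entries by key gives two maps with the same key-value pairs an
  // identical entry sequence, so a downstream byte-wise (UnsafeRow-style)
  // comparison of the encoded rows agrees with map equality.
  object SortMapKeySketch {
    def normalize[K: Ordering, V](m: Map[K, V]): Seq[(K, V)] = m.toSeq.sortBy(_._1)

    def main(args: Array[String]): Unit = {
      val m1 = Map("a" -> 1, "b" -> 2)
      val m2 = Map("b" -> 2, "a" -> 1)
      // Same entries, different insertion order: after normalization both yield
      // Seq(("a", 1), ("b", 2)), so they group/join/compare as equal.
      assert(normalize(m1) == normalize(m2))
    }
  }

In SortMapKey itself the sorted entries are rebuilt through the ArrayBasedMapBuilder shown above rather than a Scala Seq, so downstream operators still see an ordinary map value.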
val normalizedGroupingExpressions = namedGroupingExpressions.map { e => NormalizeFloatingNumbers.normalize(e) match { case n: NamedExpression => n @@ -446,8 +446,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "Spark user mailing list.") } - // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here because - // `groupingExpressions` is not extracted during logical phase. + // Ideally this should be done in `NormalizeFloatingNumbers` and `NormalizeMapType`, + // but we do it here because `groupingExpressions` is not extracted during logical phase. val normalizedGroupingExpressions = groupingExpressions.map { e => NormalizeFloatingNumbers.normalize(e) match { case n: NamedExpression => n @@ -478,8 +478,9 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { val distinctExpressions = functionsWithDistinct.head.aggregateFunction.children.filterNot(_.foldable) val normalizedNamedDistinctExpressions = distinctExpressions.map { e => - // Ideally this should be done in `NormalizeFloatingNumbers`, but we do it here - // because `distinctExpressions` is not extracted during logical phase. + // Ideally this should be done in `NormalizeFloatingNumbers` and `NormalizeMapType`, + // but we do it here because `distinctExpressions` is not extracted during + // logical phase. NormalizeFloatingNumbers.normalize(e) match { case ne: NamedExpression => ne case other => From a08356ae30a94e7422b99ab5ee3b20d4ce2d9554 Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Fri, 26 Mar 2021 15:22:08 +0800 Subject: [PATCH 4/5] fix style --- .../org/apache/spark/sql/SQLQuerySuite.scala | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index f13da68924eba..51ad5ab4457af 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -513,7 +513,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark Row("1")) } - def sorttest(): Unit = { + def sortTest(): Unit = { checkAnswer( sql("SELECT * FROM testData2 ORDER BY a ASC, b ASC"), Seq(Row(1, 1), Row(1, 2), Row(2, 1), Row(2, 2), Row(3, 1), Row(3, 2))) @@ -556,7 +556,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } test("external sorting") { - sorttest() + sortTest() } test("CTE feature") { @@ -4119,7 +4119,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark } } } - + test("SPARK-33482: Fix FileScan canonicalization") { withSQLConf(SQLConf.USE_V1_SOURCE_LIST.key -> "") { withTempPath { path => @@ -4147,8 +4147,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark Seq(CodegenObjectFactoryMode.CODEGEN_ONLY.toString, CodegenObjectFactoryMode.NO_CODEGEN.toString).foreach { case codegenMode => - withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode, - SQLConf.USE_OBJECT_HASH_AGG.key -> "false") { + withSQLConf(SQLConf.CODEGEN_FACTORY_MODE.key -> codegenMode) { withTable("t", "t2") { val df = Seq( Map("a" -> 1, "b" -> 2, "c" -> 3), @@ -4156,9 +4155,10 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark Map("d" -> 4), Map("a" -> 1, "e" -> 2), Map("d" -> 4), - Map("d" -> 5)).toDF("m") + Map("d" -> 5) + ).toDF("m") val df2 = Seq( - Map("b" -> 2, "a" -> 1, "c" -> 3), + Map("b" -> 2, "a" -> 1, "c" -> 3) 
).toDF("m2") df.createOrReplaceTempView("t") df2.createOrReplaceTempView("t2") @@ -4210,7 +4210,7 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark |) |group by m2 |""".stripMargin), - Row(Map("a" -> 1, "b" -> 2, "c" -> 3), 6, 1.0) + Row(Map("a" -> 1, "b" -> 2, "c" -> 3), 6, 1.0) :: Nil ) checkAnswer( From 58bb3cc5046563f18a7c8243d8b5b6c1d86ac5de Mon Sep 17 00:00:00 2001 From: "wangguangxin.cn" Date: Fri, 26 Mar 2021 18:21:24 +0800 Subject: [PATCH 5/5] fix ut --- .../expressions/codegen/CodeGenerator.scala | 19 +++++++----- .../analysis/AnalysisErrorSuite.scala | 31 ++++++++----------- .../ExpressionTypeCheckingSuite.scala | 4 +-- .../sql/DataFrameSetOperationsSuite.scala | 12 +++---- 4 files changed, 33 insertions(+), 33 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala index 9506a59865267..d1f36a98831b1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodeGenerator.scala @@ -694,20 +694,28 @@ class CodegenContext extends Logging { val keyIndexComparator = freshName("keyIndexComparator") val compareKeyFunc = freshName("compareKey") val compareValueFunc = freshName("compareValue") - val nullValueCheck = if (valueContainsNull) { + val nullSafeCompare = + s""" + |${javaType(valueType)} left = ${getValue("leftArray", valueType, "leftIndex")}; + |${javaType(valueType)} right = ${getValue("rightArray", valueType, "rightIndex")}; + |return ${genComp(valueType, "left", "right")}; + |""".stripMargin + val compareElement = if (valueContainsNull) { s""" |boolean isNullA = leftArray.isNullAt(leftIndex); |boolean isNullB = rightArray.isNullAt(rightIndex); |if (isNullA && isNullB) { - | // do nothing + | return 0; |} else if (isNullA) { | return -1; |} else if (isNullB) { | return 1; + |} else { + | $nullSafeCompare |} |""".stripMargin } else { - "" + nullSafeCompare } addNewFunction(initIndexArrayFunc, @@ -753,10 +761,7 @@ class CodegenContext extends Logging { s""" |private int $compareValueFunc(ArrayData leftArray, int leftIndex, ArrayData rightArray, | int rightIndex) { - | $nullValueCheck - | ${javaType(valueType)} left = ${getValue("leftArray", valueType, "leftIndex")}; - | ${javaType(valueType)} right = ${getValue("rightArray", valueType, "rightIndex")}; - | return ${genComp(valueType, "left", "right")}; + | $compareElement |} |""".stripMargin) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala index dec9e8a0da9cb..c84de21f3dddc 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala @@ -52,34 +52,34 @@ private[sql] class GroupableUDT extends UserDefinedType[GroupableData] { private[spark] override def asNullable: GroupableUDT = this } -private[sql] case class UngroupableData(data: Map[Int, Int]) { +private[sql] case class GroupableData2(data: Map[Int, Int]) { def getData: Map[Int, Int] = data } -private[sql] class UngroupableUDT extends UserDefinedType[UngroupableData] { +private[sql] class GroupableUDT2 extends UserDefinedType[GroupableData2] { override def 
sqlType: DataType = MapType(IntegerType, IntegerType) - override def serialize(ungroupableData: UngroupableData): MapData = { - val keyArray = new GenericArrayData(ungroupableData.data.keys.toSeq) - val valueArray = new GenericArrayData(ungroupableData.data.values.toSeq) + override def serialize(groupableData: GroupableData2): MapData = { + val keyArray = new GenericArrayData(groupableData.data.keys.toSeq) + val valueArray = new GenericArrayData(groupableData.data.values.toSeq) new ArrayBasedMapData(keyArray, valueArray) } - override def deserialize(datum: Any): UngroupableData = { + override def deserialize(datum: Any): GroupableData2 = { datum match { case data: MapData => val keyArray = data.keyArray().array val valueArray = data.valueArray().array assert(keyArray.length == valueArray.length) val mapData = keyArray.zip(valueArray).toMap.asInstanceOf[Map[Int, Int]] - UngroupableData(mapData) + GroupableData2(mapData) } } - override def userClass: Class[UngroupableData] = classOf[UngroupableData] + override def userClass: Class[GroupableData2] = classOf[GroupableData2] - private[spark] override def asNullable: UngroupableUDT = this + private[spark] override def asNullable: GroupableUDT2 = this } case class TestFunction( @@ -587,19 +587,14 @@ class AnalysisErrorSuite extends AnalysisTest { new StructType() .add("f1", FloatType, nullable = true) .add("f2", ArrayType(BooleanType, containsNull = true), nullable = true), - new GroupableUDT()) - supportedDataTypes.foreach { dataType => - checkDataType(dataType, shouldSuccess = true) - } - - val unsupportedDataTypes = Seq( MapType(StringType, LongType), new StructType() .add("f1", FloatType, nullable = true) .add("f2", MapType(StringType, LongType), nullable = true), - new UngroupableUDT()) - unsupportedDataTypes.foreach { dataType => - checkDataType(dataType, shouldSuccess = false) + new GroupableUDT(), + new GroupableUDT2()) + supportedDataTypes.foreach { dataType => + checkDataType(dataType, shouldSuccess = true) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala index ce49622a3234f..2ac8fd3c9aac6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/ExpressionTypeCheckingSuite.scala @@ -143,8 +143,8 @@ class ExpressionTypeCheckingSuite extends SparkFunSuite { assertSuccess(new BoolAnd(Symbol("booleanField"))) assertSuccess(new BoolOr(Symbol("booleanField"))) - assertError(Min(Symbol("mapField")), "min does not support ordering on type") - assertError(Max(Symbol("mapField")), "max does not support ordering on type") + assertSuccess(Min(Symbol("mapField"))) + assertSuccess(Max(Symbol("mapField"))) assertError(Sum(Symbol("booleanField")), "function sum requires numeric type") assertError(Average(Symbol("booleanField")), "function average requires numeric type") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala index 70dee3166fc06..376ea07ad513d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSetOperationsSuite.scala @@ -348,22 +348,22 @@ class DataFrameSetOperationsSuite extends QueryTest with SharedSparkSession { } 
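The test changes above assume that MapType now carries a total ordering, which is also what lets min/max over a map column pass type checking. A rough sketch of one such ordering over ordinary Scala maps (illustrative only; the interpreted and codegen comparators in this patch operate on Spark's ArrayData and additionally order null values before non-null ones, and the object name here is made up):

  // One possible total ordering over maps: compare entries in ascending key
  // order (keys first, then values), with the shorter map first when one is a
  // prefix of the other.
  object MapOrderingSketch {
    implicit def mapOrdering[K, V](implicit kOrd: Ordering[K],
        vOrd: Ordering[V]): Ordering[Map[K, V]] = new Ordering[Map[K, V]] {
      override def compare(x: Map[K, V], y: Map[K, V]): Int = {
        @annotation.tailrec
        def loop(xs: List[(K, V)], ys: List[(K, V)]): Int = (xs, ys) match {
          case (Nil, Nil) => 0
          case (Nil, _) => -1 // shorter map sorts first
          case (_, Nil) => 1
          case ((xk, xv) :: xt, (yk, yv) :: yt) =>
            val byKey = kOrd.compare(xk, yk)
            if (byKey != 0) byKey
            else {
              val byValue = vOrd.compare(xv, yv)
              if (byValue != 0) byValue else loop(xt, yt)
            }
        }
        loop(x.toList.sortBy(_._1), y.toList.sortBy(_._1))
      }
    }

    def main(args: Array[String]): Unit = {
      val maps = Seq(Map("d" -> 4), Map("a" -> 1, "b" -> 2), Map("a" -> 1, "e" -> 2))
      // "d" sorts after "a" on the first key, so Map(d -> 4) is the maximum.
      assert(maps.max == Map("d" -> 4))
    }
  }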
test("SPARK-34819: set operations with map type") { - val df = spark.range(0, 2).select(map(lit("key"), $"id").as("m")) - val df2 = spark.range(1, 2).select(map(lit("key"), $"id").as("m")) + val df = Seq(Map("a" -> 1, "b" -> 2), Map("c" -> 3)).toDF("m") + val df2 = Seq(Map("b" -> 2, "a" -> 1), Map("c" -> 4)).toDF("m") checkAnswer( df.intersect(df2), - Row(Map("key" -> "1")) :: Nil + Row(Map("a" -> 1, "b" -> 2)) :: Nil ) checkAnswer( df.except(df2), - Row(Map("key" -> "0")) :: Nil + Row(Map("c" -> 3)) :: Nil ) checkAnswer( df.distinct(), - Row(Map("key" -> "0")) :: - Row(Map("key" -> "1")) :: Nil + Row(Map("a" -> 1, "b" -> 2)) :: + Row(Map("c" -> 3)) :: Nil ) }