diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index fcaf2b1d9d301..3786643125a9f 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -631,6 +631,11 @@ "Cannot process input data types for the expression: <expression>." ], "subClass" : { + "BAD_INPUTS" : { + "message" : [ + "The input data types to <functionName> must be valid, but found the input types <dataType>." + ] + }, "MISMATCHED_TYPES" : { "message" : [ "All input types must be the same except nullable, containsNull, valueContainsNull flags, but found the input types <inputTypes>." @@ -1011,6 +1016,11 @@ "The input of <functionName> can't be <dataType> type data." ] }, + "UNSUPPORTED_MODE_DATA_TYPE" : { + "message" : [ + "The <mode> does not support the <child> data type, because there is a \"MAP\" type with keys and/or values that have collated sub-fields." + ] + }, "UNSUPPORTED_UDF_INPUT_TYPE" : { "message" : [ "UDFs do not support '<dataType>' as an input data type." 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala index e254a670991a1..8998348f0571b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala @@ -17,14 +17,17 @@ package org.apache.spark.sql.catalyst.expressions.aggregate +import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup} import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder} +import org.apache.spark.sql.catalyst.expressions.Cast.toSQLExpr import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.types.PhysicalDataType -import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData, UnsafeRowUtils} +import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, UnsafeRowUtils} +import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType} import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, StringType} +import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, MapType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.OpenHashMap @@ -50,17 +53,20 @@ case class Mode( override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType) override def checkInputDataTypes(): TypeCheckResult = { - if (UnsafeRowUtils.isBinaryStable(child.dataType) || child.dataType.isInstanceOf[StringType]) { + // 
TODO: SPARK-49358: Mode expression for map type with collated fields + if (UnsafeRowUtils.isBinaryStable(child.dataType) || + !child.dataType.existsRecursively(f => f.isInstanceOf[MapType] && + !UnsafeRowUtils.isBinaryStable(f))) { /* * The Mode class uses collation awareness logic to handle string data. - * Complex types with collated fields are not yet supported. + * All complex types except MapType with collated fields are supported. */ - // TODO: SPARK-48700: Mode expression for complex types (all collations) super.checkInputDataTypes() } else { - TypeCheckResult.TypeCheckFailure("The input to the function 'mode' was" + - " a type of binary-unstable type that is " + - s"not currently supported by ${prettyName}.") + TypeCheckResult.DataTypeMismatch("UNSUPPORTED_MODE_DATA_TYPE", + messageParameters = + Map("child" -> toSQLType(child.dataType), + "mode" -> toSQLId(prettyName))) } } @@ -86,6 +92,54 @@ case class Mode( buffer } + private def getCollationAwareBuffer( + childDataType: DataType, + buffer: OpenHashMap[AnyRef, Long]): Iterable[(AnyRef, Long)] = { + def groupAndReduceBuffer(groupingFunction: AnyRef => _): Iterable[(AnyRef, Long)] = { + buffer.groupMapReduce(t => + groupingFunction(t._1))(x => x)((x, y) => (x._1, x._2 + y._2)).values + } + def determineBufferingFunction( + childDataType: DataType): Option[AnyRef => _] = { + childDataType match { + case _ if UnsafeRowUtils.isBinaryStable(child.dataType) => None + case _ => Some(collationAwareTransform(_, childDataType)) + } + } + determineBufferingFunction(childDataType).map(groupAndReduceBuffer).getOrElse(buffer) + } + + protected[sql] def collationAwareTransform(data: AnyRef, dataType: DataType): AnyRef = { + dataType match { + case _ if UnsafeRowUtils.isBinaryStable(dataType) => data + case st: StructType => + processStructTypeWithBuffer(data.asInstanceOf[InternalRow].toSeq(st).zip(st.fields)) + case at: ArrayType => processArrayTypeWithBuffer(at, data.asInstanceOf[ArrayData]) + case st: StringType => + 
CollationFactory.getCollationKey(data.asInstanceOf[UTF8String], st.collationId) + case _ => + throw new SparkIllegalArgumentException( + errorClass = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.BAD_INPUTS", + messageParameters = Map( + "expression" -> toSQLExpr(this), + "functionName" -> toSQLType(prettyName), + "dataType" -> toSQLType(child.dataType)) + ) + } + } + + private def processStructTypeWithBuffer( + tuples: Seq[(Any, StructField)]): Seq[Any] = { + tuples.map(t => collationAwareTransform(t._1.asInstanceOf[AnyRef], t._2.dataType)) + } + + private def processArrayTypeWithBuffer( + a: ArrayType, + data: ArrayData): Seq[Any] = { + (0 until data.numElements()).map(i => + collationAwareTransform(data.get(i, a.elementType), a.elementType)) + } + override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = { if (buffer.isEmpty) { return null @@ -102,17 +156,12 @@ case class Mode( * to a single value (the sum of the counts), and finally reduces the groups to a single map. * * The new map is then used in the rest of the Mode evaluation logic. + * + * It is expected to work for all simple and complex types with + * collated fields, except for MapType (temporarily). 
*/ - val collationAwareBuffer = child.dataType match { - case c: StringType if - !CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality => - val collationId = c.collationId - val modeMap = buffer.toSeq.groupMapReduce { - case (k, _) => CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId) - }(x => x)((x, y) => (x._1, x._2 + y._2)).values - modeMap - case _ => buffer - } + val collationAwareBuffer = getCollationAwareBuffer(child.dataType, buffer) + reverseOpt.map { reverse => val defaultKeyOrdering = if (reverse) { PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]].reverse diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala index 941d5cd31db40..9930709cd8bf3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpressionsSuite.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql import java.sql.{Date, Timestamp} import java.text.SimpleDateFormat +import java.util.Locale import scala.collection.immutable.Seq -import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException} -import org.apache.spark.sql.catalyst.ExtendedAnalysisException +import org.apache.spark.{SparkConf, SparkException, SparkIllegalArgumentException, SparkRuntimeException, SparkThrowable} +import org.apache.spark.sql.catalyst.{ExtendedAnalysisException, InternalRow} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.Mode import org.apache.spark.sql.internal.{SqlApiConf, SQLConf} @@ -1752,7 +1753,7 @@ class CollationSQLExpressionsSuite UTF8StringModeTestCase("unicode_ci", bufferValuesUTF8String, "b"), UTF8StringModeTestCase("unicode", bufferValuesUTF8String, "a")) - testCasesUTF8String.foreach(t => { + 
testCasesUTF8String.foreach ( t => { val buffer = new OpenHashMap[AnyRef, Long](5) val myMode = Mode(child = Literal.create("some_column_name", StringType(t.collationId))) t.bufferValues.foreach { case (k, v) => buffer.update(k, v) } @@ -1760,6 +1761,40 @@ class CollationSQLExpressionsSuite }) } + test("Support Mode.eval(buffer) with complex types") { + case class UTF8StringModeTestCase[R]( + collationId: String, + bufferValues: Map[InternalRow, Long], + result: R) + + val bufferValuesUTF8String: Map[Any, Long] = Map( + UTF8String.fromString("a") -> 5L, + UTF8String.fromString("b") -> 4L, + UTF8String.fromString("B") -> 3L, + UTF8String.fromString("d") -> 2L, + UTF8String.fromString("e") -> 1L) + + val bufferValuesComplex = bufferValuesUTF8String.map{ + case (k, v) => (InternalRow.fromSeq(Seq(k, k, k)), v) + } + val testCasesUTF8String = Seq( + UTF8StringModeTestCase("utf8_binary", bufferValuesComplex, "[a,a,a]"), + UTF8StringModeTestCase("UTF8_LCASE", bufferValuesComplex, "[b,b,b]"), + UTF8StringModeTestCase("unicode_ci", bufferValuesComplex, "[b,b,b]"), + UTF8StringModeTestCase("unicode", bufferValuesComplex, "[a,a,a]")) + + testCasesUTF8String.foreach { t => + val buffer = new OpenHashMap[AnyRef, Long](5) + val myMode = Mode(child = Literal.create(null, StructType(Seq( + StructField("f1", StringType(t.collationId), true), + StructField("f2", StringType(t.collationId), true), + StructField("f3", StringType(t.collationId), true) + )))) + t.bufferValues.foreach { case (k, v) => buffer.update(k, v) } + assert(myMode.eval(buffer).toString.toLowerCase() == t.result.toLowerCase()) + } + } + test("Support mode for string expression with collated strings in struct") { case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) val testCases = Seq( @@ -1780,33 +1815,7 @@ class CollationSQLExpressionsSuite t.collationId + ", f2: INT>) USING parquet") sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) val query = s"SELECT 
lower(mode(i).f1) FROM ${tableName}" - if(t.collationId == "UTF8_LCASE" || - t.collationId == "unicode_ci" || - t.collationId == "unicode") { - // Cannot resolve "mode(i)" due to data type mismatch: - // Input to function mode was a complex type with strings collated on non-binary - // collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13; - val params = Seq(("sqlExpr", "\"mode(i)\""), - ("msg", "The input to the function 'mode'" + - " was a type of binary-unstable type that is not currently supported by mode."), - ("hint", "")).toMap - checkError( - exception = intercept[AnalysisException] { - sql(query) - }, - condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", - parameters = params, - queryContext = Array( - ExpectedContext(objectType = "", - objectName = "", - startIndex = 13, - stopIndex = 19, - fragment = "mode(i)") - ) - ) - } else { - checkAnswer(sql(query), Row(t.result)) - } + checkAnswer(sql(query), Row(t.result)) } }) } @@ -1819,47 +1828,21 @@ class CollationSQLExpressionsSuite ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") ) - testCases.foreach(t => { + testCases.foreach { t => val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => (0L to numRepeats).map(_ => s"named_struct('f1', " + s"named_struct('f2', collate('$elt', '${t.collationId}')), 'f3', 1)").mkString(",") }.mkString(",") - val tableName = s"t_${t.collationId}_mode_nested_struct" + val tableName = s"t_${t.collationId}_mode_nested_struct1" withTable(tableName) { sql(s"CREATE TABLE ${tableName}(i STRUCT, f3: INT>) USING parquet") sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) val query = s"SELECT lower(mode(i).f1.f2) FROM ${tableName}" - if(t.collationId == "UTF8_LCASE" || - t.collationId == "unicode_ci" || - t.collationId == "unicode") { - // Cannot resolve "mode(i)" due to data type mismatch: - // Input to function mode was a complex type with strings 
collated on non-binary - // collations, which is not yet supported.. SQLSTATE: 42K09; line 1 pos 13; - val params = Seq(("sqlExpr", "\"mode(i)\""), - ("msg", "The input to the function 'mode' " + - "was a type of binary-unstable type that is not currently supported by mode."), - ("hint", "")).toMap - checkError( - exception = intercept[AnalysisException] { - sql(query) - }, - condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", - parameters = params, - queryContext = Array( - ExpectedContext(objectType = "", - objectName = "", - startIndex = 13, - stopIndex = 19, - fragment = "mode(i)") - ) - ) - } else { - checkAnswer(sql(query), Row(t.result)) - } + checkAnswer(sql(query), Row(t.result)) } - }) + } } test("Support mode for string expression with collated strings in array complex type") { @@ -1870,44 +1853,150 @@ class CollationSQLExpressionsSuite ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") ) - testCases.foreach(t => { + testCases.foreach { t => + val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => + (0L to numRepeats).map(_ => s"array(named_struct('f2', " + + s"collate('$elt', '${t.collationId}'), 'f3', 1))").mkString(",") + }.mkString(",") + + val tableName = s"t_${t.collationId}_mode_nested_struct2" + withTable(tableName) { + sql(s"CREATE TABLE ${tableName}(" + + s"i ARRAY< STRUCT>)" + + s" USING parquet") + sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) + val query = s"SELECT lower(element_at(mode(i).f2, 1)) FROM ${tableName}" + checkAnswer(sql(query), Row(t.result)) + } + } + } + + test("Support mode for string expression with collated strings in 3D array type") { + case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) + val testCases = Seq( + ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + 
ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") + ) + testCases.foreach { t => + val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => + (0L to numRepeats).map(_ => + s"array(array(array(collate('$elt', '${t.collationId}'))))").mkString(",") + }.mkString(",") + + val tableName = s"t_${t.collationId}_mode_nested_3d_array" + withTable(tableName) { + sql(s"CREATE TABLE ${tableName}(i ARRAY>>) USING parquet") + sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) + val query = s"SELECT lower(" + + s"element_at(element_at(element_at(mode(i),1),1),1)) FROM ${tableName}" + checkAnswer(sql(query), Row(t.result)) + } + } + } + + test("Support mode for string expression with collated complex type - Highly nested") { + case class ModeTestCase[R](collationId: String, bufferValues: Map[String, Long], result: R) + val testCases = Seq( + ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("UTF8_LCASE", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") + ) + testCases.foreach { t => val valuesToAdd = t.bufferValues.map { case (elt, numRepeats) => (0L to numRepeats).map(_ => s"array(named_struct('s1', named_struct('a2', " + s"array(collate('$elt', '${t.collationId}'))), 'f3', 1))").mkString(",") }.mkString(",") - val tableName = s"t_${t.collationId}_mode_nested_struct" + val tableName = s"t_${t.collationId}_mode_highly_nested_struct" withTable(tableName) { sql(s"CREATE TABLE ${tableName}(" + s"i ARRAY>, f3: INT>>)" + s" USING parquet") sql(s"INSERT INTO ${tableName} VALUES " + valuesToAdd) val query = s"SELECT lower(element_at(element_at(mode(i), 1).s1.a2, 1)) FROM ${tableName}" - if(t.collationId == "UTF8_LCASE" || - t.collationId == "unicode_ci" || t.collationId == "unicode") { - val params = Seq(("sqlExpr", 
"\"mode(i)\""), - ("msg", "The input to the function 'mode' was a type" + - " of binary-unstable type that is not currently supported by mode."), - ("hint", "")).toMap - checkError( - exception = intercept[AnalysisException] { - sql(query) - }, - condition = "DATATYPE_MISMATCH.TYPE_CHECK_FAILURE_WITH_HINT", - parameters = params, - queryContext = Array( - ExpectedContext(objectType = "", - objectName = "", - startIndex = 35, - stopIndex = 41, - fragment = "mode(i)") - ) - ) - } else { + checkAnswer(sql(query), Row(t.result)) + } + } + } + + test("Support mode expression with collated in recursively nested struct with map with keys") { + case class ModeTestCase(collationId: String, bufferValues: Map[String, Long], result: String) + Seq( + ModeTestCase("utf8_binary", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"), + ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{a -> 1}"), + ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "{b -> 1}") + ).foreach { t1 => + def checkThisError(t: ModeTestCase, query: String): Any = { + val c = s"STRUCT>" + val c1 = s"\"${c}\"" + checkError( + exception = intercept[SparkThrowable] { + sql(query).collect() + }, + condition = "DATATYPE_MISMATCH.UNSUPPORTED_MODE_DATA_TYPE", + parameters = Map( + ("sqlExpr", "\"mode(i)\""), + ("child", c1), + ("mode", "`mode`")), + queryContext = Seq(ExpectedContext("mode(i)", 18, 24)).toArray + ) + } + + def getValuesToAdd(t: ModeTestCase): String = { + val valuesToAdd = t.bufferValues.map { + case (elt, numRepeats) => + (0L to numRepeats).map(i => + s"named_struct('m1', map(collate('$elt', '${t.collationId}'), 1))" + ).mkString(",") + }.mkString(",") + valuesToAdd + } + val tableName = s"t_${t1.collationId}_mode_nested_map_struct1" + withTable(tableName) { + sql(s"CREATE TABLE ${tableName}(" + + s"i STRUCT>) USING parquet") + sql(s"INSERT INTO ${tableName} VALUES ${getValuesToAdd(t1)}") + 
val query = "SELECT lower(cast(mode(i).m1 as string))" + + s" FROM ${tableName}" + if (t1.collationId == "utf8_binary") { + checkAnswer(sql(query), Row(t1.result)) + } else { + checkThisError(t1, query) } } - }) + } + } + + test("UDT with collation - Mode (throw exception)") { + case class ModeTestCase(collationId: String, bufferValues: Map[String, Long], result: String) + Seq( + ModeTestCase("utf8_lcase", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b"), + ModeTestCase("unicode", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "a"), + ModeTestCase("unicode_ci", Map("a" -> 3L, "b" -> 2L, "B" -> 2L), "b") + ).foreach { t1 => + checkError( + exception = intercept[SparkIllegalArgumentException] { + Mode( + child = Literal.create(null, + MapType(StringType(t1.collationId), IntegerType)) + ).collationAwareTransform( + data = Map.empty[String, Any], + dataType = MapType(StringType(t1.collationId), IntegerType) + ) + }, + condition = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.BAD_INPUTS", + parameters = Map( + "expression" -> "\"mode(NULL)\"", + "functionName" -> "\"MODE\"", + "dataType" -> s"\"MAP\"") + ) + } } test("SPARK-48430: Map value extraction with collations") {