Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-48700][SQL] Mode expression for complex types (all collations) #47154

Closed
wants to merge 27 commits into from
Closed
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
a909cb1
reapply changes
GideonPotok Sep 11, 2024
ce865a9
SPARK COLLATIONS MAP
GideonPotok Sep 11, 2024
d697403
formatting
GideonPotok Sep 13, 2024
535b16b
Revert "formatting"
GideonPotok Sep 13, 2024
066ebd4
formatting
GideonPotok Sep 13, 2024
f2d0503
formatting
GideonPotok Sep 13, 2024
4f0cfbe
formatting
GideonPotok Sep 13, 2024
432de23
move reason
GideonPotok Sep 17, 2024
a8d626b
four spaces for classes
GideonPotok Sep 20, 2024
d621b8a
fix indentation of method params
GideonPotok Sep 20, 2024
af97fe8
fix indentation of method params
GideonPotok Sep 20, 2024
bf91fe9
fix indentation of method params
GideonPotok Sep 20, 2024
0b7364f
fix indentation of method params
GideonPotok Sep 20, 2024
ca564d3
fix indentation of method params
GideonPotok Sep 20, 2024
96c742f
Update common/utils/src/main/resources/error/error-conditions.json
GideonPotok Sep 20, 2024
d5552cd
Update sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpre…
GideonPotok Sep 20, 2024
ce8986f
Update sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpre…
GideonPotok Sep 20, 2024
2632b91
Update sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpre…
GideonPotok Sep 20, 2024
72483ac
Update sql/core/src/test/scala/org/apache/spark/sql/CollationSQLExpre…
GideonPotok Sep 20, 2024
4695462
fix call to throw SparkUnsupportedOperationException
GideonPotok Sep 20, 2024
b285a6f
Apply suggestions from code review
GideonPotok Sep 20, 2024
e330698
fix
GideonPotok Sep 24, 2024
f4074be
hello
GideonPotok Sep 26, 2024
adae8f3
passing tests
GideonPotok Sep 28, 2024
37efd0c
passing tests
GideonPotok Sep 28, 2024
afd123b
Added COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.BAD_INPUTS. Tests pass.
GideonPotok Sep 29, 2024
f4c39b1
reformat error-conditions.json for test 'Error conditions are correct…
GideonPotok Sep 30, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,11 @@
"The input of <functionName> can't be <dataType> type data."
]
},
"UNSUPPORTED_MODE_DATA_TYPE" : {
"message" : [
"The <mode> does not support the <child> data type, because there is a \"MAP\" type with keys and/or values that have collated sub-fields."
]
},
"UNSUPPORTED_UDF_INPUT_TYPE" : {
"message" : [
"UDFs do not support '<dataType>' as an input data type."
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,16 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, TypeCheckResult, UnresolvedWithinGroup}
import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder}
import org.apache.spark.sql.catalyst.trees.UnaryLike
import org.apache.spark.sql.catalyst.types.PhysicalDataType
import org.apache.spark.sql.catalyst.util.{CollationFactory, GenericArrayData, UnsafeRowUtils}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, UnsafeRowUtils}
import org.apache.spark.sql.errors.DataTypeErrors.{toSQLId, toSQLType}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, StringType}
import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, MapType, StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.collection.OpenHashMap

Expand All @@ -50,17 +52,20 @@ case class Mode(
override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType)

override def checkInputDataTypes(): TypeCheckResult = {
if (UnsafeRowUtils.isBinaryStable(child.dataType) || child.dataType.isInstanceOf[StringType]) {
// TODO: SPARK-49358: Mode expression for map type with collated fields
if (UnsafeRowUtils.isBinaryStable(child.dataType) ||
!child.dataType.existsRecursively(f => f.isInstanceOf[MapType] &&
!UnsafeRowUtils.isBinaryStable(f))) {
GideonPotok marked this conversation as resolved.
Show resolved Hide resolved
/*
* The Mode class uses collation awareness logic to handle string data.
* Complex types with collated fields are not yet supported.
* All complex types except MapType with collated fields are supported.
*/
// TODO: SPARK-48700: Mode expression for complex types (all collations)
super.checkInputDataTypes()
} else {
TypeCheckResult.TypeCheckFailure("The input to the function 'mode' was" +
" a type of binary-unstable type that is " +
s"not currently supported by ${prettyName}.")
TypeCheckResult.DataTypeMismatch("UNSUPPORTED_MODE_DATA_TYPE",
messageParameters =
Map("child" -> toSQLType(child.dataType),
"mode" -> toSQLId(prettyName)))
}
}

Expand All @@ -86,6 +91,53 @@ case class Mode(
buffer
}

private def getCollationAwareBuffer(
childDataType: DataType,
buffer: OpenHashMap[AnyRef, Long]): Iterable[(AnyRef, Long)] = {
def groupAndReduceBuffer(groupingFunction: AnyRef => _): Iterable[(AnyRef, Long)] = {
buffer.groupMapReduce(t =>
groupingFunction(t._1))(x => x)((x, y) => (x._1, x._2 + y._2)).values
}
def determineBufferingFunction(
childDataType: DataType): Option[AnyRef => _] = {
childDataType match {
case _ if UnsafeRowUtils.isBinaryStable(child.dataType) => None
case _ => Some(collationAwareTransform(_, childDataType))
}
}
determineBufferingFunction(childDataType).map(groupAndReduceBuffer).getOrElse(buffer)
}

private def collationAwareTransform(data: AnyRef, dataType: DataType): AnyRef = {
dataType match {
case _ if UnsafeRowUtils.isBinaryStable(dataType) => data
case st: StructType =>
processStructTypeWithBuffer(data.asInstanceOf[InternalRow].toSeq(st).zip(st.fields))
case at: ArrayType => processArrayTypeWithBuffer(at, data.asInstanceOf[ArrayData])
case st: StringType =>
CollationFactory.getCollationKey(data.asInstanceOf[UTF8String], st.collationId)
case _ =>
throw new SparkUnsupportedOperationException(
"UNSUPPORTED_MODE_DATA_TYPE",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

UNSUPPORTED_MODE_DATA_TYPE is a sub-condition of DATATYPE_MISMATCH. You cannot refers to it in this way. Could you add a test for this case, please.

BTW, if you cannot reproduce the error via public API, we should consider to convert it to an internal error.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MaxGekk For the test, I suppose the situation where a type is not binary stable and is not covered by checkInputDataTypes would include some UDTs? checkInputType just confirms it is not a MapType (a blacklist/blocklist), whereas collationAwareTransform is an allowlist/whitelist).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MaxGekk @uros-db I am having a lot of trouble with this one! I implemented it as COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.NO_INPUT, and plan to create a different subclass that represents this situation (Maybe BAD_INPUT), once I just have it working. But When I throw the following:

SparkUnsupportedOperationException(
          errorClass = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.NO_INPUT",

I end up getting:

// org.apache.spark.SparkException: [INTERNAL_ERROR] Cannot
        // find sub error class 'COMPLEX_EXPRESSION_UNSUPPORTED_INPUT.NO_INPUT' SQLSTATE: XX000

Yet if I do

SparkUnsupportedOperationException(
          errorClass = "COMPLEX_EXPRESSION_UNSUPPORTED_INPUT",

Then org.apache.spark.ErrorClassesJsonReader#getMessageTemplate fails during the assertion assert(errorInfo.subClass.isDefined == subErrorClass.isDefined)

As the subClass is missing.

Would either of you be able to tell me if this is the right pattern (eg. maybe it should be ComplexExpressionException("BAD_INPUT"), not sure what is the "old" pattern and which is the preferred one).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@MaxGekk please re-review

messageParameters =
Map("child" -> toSQLType(child.dataType),
"mode" -> toSQLId(prettyName))
)
}
}

private def processStructTypeWithBuffer(
tuples: Seq[(Any, StructField)]): Seq[Any] = {
tuples.map(t => collationAwareTransform(t._1.asInstanceOf[AnyRef], t._2.dataType))
}

private def processArrayTypeWithBuffer(
a: ArrayType,
data: ArrayData): Seq[Any] = {
(0 until data.numElements()).map(i =>
collationAwareTransform(data.get(i, a.elementType), a.elementType))
}

GideonPotok marked this conversation as resolved.
Show resolved Hide resolved
override def eval(buffer: OpenHashMap[AnyRef, Long]): Any = {
if (buffer.isEmpty) {
return null
Expand All @@ -102,17 +154,12 @@ case class Mode(
* to a single value (the sum of the counts), and finally reduces the groups to a single map.
*
* The new map is then used in the rest of the Mode evaluation logic.
*
* It is expected to work for all simple and complex types with
* collated fields, except for MapType (temporarily).
*/
val collationAwareBuffer = child.dataType match {
case c: StringType if
!CollationFactory.fetchCollation(c.collationId).supportsBinaryEquality =>
val collationId = c.collationId
val modeMap = buffer.toSeq.groupMapReduce {
case (k, _) => CollationFactory.getCollationKey(k.asInstanceOf[UTF8String], collationId)
}(x => x)((x, y) => (x._1, x._2 + y._2)).values
modeMap
case _ => buffer
}
val collationAwareBuffer = getCollationAwareBuffer(child.dataType, buffer)
GideonPotok marked this conversation as resolved.
Show resolved Hide resolved

reverseOpt.map { reverse =>
val defaultKeyOrdering = if (reverse) {
PhysicalDataType.ordering(child.dataType).asInstanceOf[Ordering[AnyRef]].reverse
Expand Down
Loading