diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala index c7a8bc74f4ddc..b6e87c456de0c 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/util/UnsafeRowUtilsSuite.scala @@ -21,7 +21,7 @@ import java.math.{BigDecimal => JavaBigDecimal} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection, UnsafeRow} -import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, Decimal, DecimalType, IntegerType, MapType, StringType, StructField, StructType} class UnsafeRowUtilsSuite extends SparkFunSuite { @@ -91,4 +91,70 @@ class UnsafeRowUtilsSuite extends SparkFunSuite { "fieldStatus:\n" + "[UnsafeRowFieldStatus] index: 0, expectedFieldType: IntegerType,")) } + + test("isBinaryStable on complex types containing collated strings") { + val nonBinaryStringType = StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")) + + // simple checks + assert(UnsafeRowUtils.isBinaryStable(IntegerType)) + assert(UnsafeRowUtils.isBinaryStable(StringType)) + assert(!UnsafeRowUtils.isBinaryStable(nonBinaryStringType)) + + assert(UnsafeRowUtils.isBinaryStable(ArrayType(IntegerType))) + assert(UnsafeRowUtils.isBinaryStable(ArrayType(StringType))) + assert(!UnsafeRowUtils.isBinaryStable(ArrayType(nonBinaryStringType))) + + assert(UnsafeRowUtils.isBinaryStable(MapType(StringType, StringType))) + assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, StringType))) + assert(!UnsafeRowUtils.isBinaryStable(MapType(StringType, nonBinaryStringType))) + assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, nonBinaryStringType))) + 
assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, IntegerType))) + assert(!UnsafeRowUtils.isBinaryStable(MapType(IntegerType, nonBinaryStringType))) + + assert(UnsafeRowUtils.isBinaryStable(StructType(StructField("field", IntegerType) :: Nil))) + assert(UnsafeRowUtils.isBinaryStable(StructType(StructField("field", StringType) :: Nil))) + assert(!UnsafeRowUtils.isBinaryStable( + StructType(StructField("field", nonBinaryStringType) :: Nil))) + + // nested complex types + assert(UnsafeRowUtils.isBinaryStable(ArrayType(ArrayType(StringType)))) + assert(UnsafeRowUtils.isBinaryStable(ArrayType(MapType(StringType, IntegerType)))) + assert(UnsafeRowUtils.isBinaryStable( + ArrayType(StructType(StructField("field", StringType) :: Nil)))) + assert(!UnsafeRowUtils.isBinaryStable(ArrayType(ArrayType(nonBinaryStringType)))) + assert(!UnsafeRowUtils.isBinaryStable(ArrayType(MapType(IntegerType, nonBinaryStringType)))) + assert(!UnsafeRowUtils.isBinaryStable( + ArrayType(MapType(IntegerType, ArrayType(nonBinaryStringType))))) + assert(!UnsafeRowUtils.isBinaryStable( + ArrayType(StructType(StructField("field", nonBinaryStringType) :: Nil)))) + assert(!UnsafeRowUtils.isBinaryStable(ArrayType(StructType( + Seq(StructField("second", IntegerType), StructField("second", nonBinaryStringType)))))) + + assert(UnsafeRowUtils.isBinaryStable(MapType(ArrayType(StringType), ArrayType(IntegerType)))) + assert(UnsafeRowUtils.isBinaryStable(MapType(MapType(StringType, StringType), IntegerType))) + assert(UnsafeRowUtils.isBinaryStable( + MapType(StructType(StructField("field", StringType) :: Nil), IntegerType))) + assert(!UnsafeRowUtils.isBinaryStable( + MapType(ArrayType(nonBinaryStringType), ArrayType(IntegerType)))) + assert(!UnsafeRowUtils.isBinaryStable( + MapType(IntegerType, ArrayType(nonBinaryStringType)))) + assert(!UnsafeRowUtils.isBinaryStable( + MapType(MapType(IntegerType, nonBinaryStringType), IntegerType))) + assert(!UnsafeRowUtils.isBinaryStable( + 
MapType(StructType(StructField("field", nonBinaryStringType) :: Nil), IntegerType))) + + assert(UnsafeRowUtils.isBinaryStable( + StructType(StructField("field", ArrayType(IntegerType)) :: Nil))) + assert(UnsafeRowUtils.isBinaryStable( + StructType(StructField("field", MapType(StringType, IntegerType)) :: Nil))) + assert(UnsafeRowUtils.isBinaryStable( + StructType(StructField("field", StructType(StructField("sub", IntegerType) :: Nil)) :: Nil))) + assert(!UnsafeRowUtils.isBinaryStable( + StructType(StructField("field", ArrayType(nonBinaryStringType)) :: Nil))) + assert(!UnsafeRowUtils.isBinaryStable( + StructType(StructField("field", MapType(nonBinaryStringType, IntegerType)) :: Nil))) + assert(!UnsafeRowUtils.isBinaryStable( + StructType(StructField("field", + StructType(StructField("sub", nonBinaryStringType) :: Nil)) :: Nil))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala index 35a0d5ed8f5e1..146ba63cf402a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala @@ -27,11 +27,11 @@ import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCu import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership +import org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} -import org.apache.spark.sql.types.{StringType, StructField, StructType} -import org.apache.spark.unsafe.types.UTF8String +import org.apache.spark.sql.types.{MapType, 
StringType, StructField, StructType} class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName @@ -641,199 +641,88 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { "reason" -> "generation expression cannot contain non-default collated string type")) } - trait ArrayCheck { - def dataType: String - def dataTypeCollated: String - } - - trait ArrayCheckSimple extends ArrayCheck { - override def dataType: String = "array<string>" - override def dataTypeCollated: String = "array<string collate utf8_binary_lcase>" - } - - trait ArrayCheckNested extends ArrayCheck { - override def dataType: String = "array<array<string>>" - override def dataTypeCollated: String = "array<array<string collate utf8_binary_lcase>>" - } - - test("Aggregation of arrays built on collated strings") { - abstract class AggCheck(val rows: Seq[String], val result: Seq[(Any, Int)]) extends ArrayCheck - - case class AggCheckSimple( - override val rows: Seq[String], - override val result: Seq[(Seq[String], Int)]) - extends AggCheck(rows, result) with ArrayCheckSimple - - case class AggCheckNested( - override val rows: Seq[String], - override val result: Seq[(Seq[Seq[String]], Int)]) - extends AggCheck(rows, result) with ArrayCheckNested - - val tableName = "test_agg_arr_collated" - val tableNameLowercase = "test_agg_arr_collated_lowercase" - - Seq( - // simple - AggCheckSimple( - rows = Seq("array('aaa')", "array('AAA')"), - result = Seq((Seq("aaa"), 2)) - ), - AggCheckSimple( - rows = Seq("array('aaa', 'bbb')", "array('AAA', 'BBB')"), - result = Seq((Seq("aaa", "bbb"), 2)) - ), - AggCheckSimple( - rows = Seq("array('aaa')", "array('bbb')", "array('AAA')", "array('BBB')"), - result = Seq((Seq("aaa"), 2), (Seq("bbb"), 2)) - ), - // nested - AggCheckNested( - rows = Seq("array(array('aaa'))", "array(array('AAA'))"), - result = Seq((Seq(Seq("aaa")), 2)) - ), - AggCheckNested( - rows = Seq("array(array('aaa'), array('bbb'))", "array(array('AAA'), array('bbb'))"), - result = 
Seq((Seq(Seq("aaa"), Seq("bbb")), 2)) - ), - AggCheckNested( - rows = Seq( - "array(array('aaa', 'aaa'), array('bbb', 'ccc'))", - "array(array('aaa', 'aaa'), array('bbb', 'ccc'), array('ddd'))", - "array(array('AAA', 'AAA'), array('BBB', 'CCC'))" - ), - result = Seq( - (Seq(Seq("aaa", "aaa"), Seq("bbb", "ccc")), 2), - (Seq(Seq("aaa", "aaa"), Seq("bbb", "ccc"), Seq("ddd")), 1) - ) + test("Aggregation on complex containing collated strings") { + val table = "table_agg" + // array + withTable(table) { + sql(s"create table $table (a array<string collate utf8_binary_lcase>) using parquet") + sql(s"insert into $table values (array('aaa')), (array('AAA'))") + checkAnswer(sql(s"select distinct a from $table"), Seq(Row(Seq("aaa")))) + } + // map doesn't support aggregation + withTable(table) { + sql(s"create table $table (m map<string collate utf8_binary_lcase, string>) using parquet") + val query = s"select distinct m from $table" + checkError( + exception = intercept[ExtendedAnalysisException](sql(query)), + errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE", + parameters = Map( + "colName" -> "`m`", + "dataType" -> toSQLType(MapType( + StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")), + StringType))), + context = ExpectedContext(query, 0, query.length - 1) ) - ).map((check: AggCheck) => - withTable(tableName, tableNameLowercase) { - def checkResults(table: String): Unit = { - checkAnswer(sql(s"select a, count(*) from $table group by a"), - check.result.map{ case (agg, cnt) => Row(agg, cnt) }) - checkAnswer(sql(s"select distinct a from $table"), - check.result.map{ case (agg, _) => Row(agg) }) - } - - // check against non-binary collation - sql(s"create table $tableName(a ${check.dataTypeCollated}) using parquet") - check.rows.map(row => sql(s"insert into $tableName(a) values($row)")) - checkResults(tableName) - - // binary collation with values converted to lowercase should match the results as well - sql(s"create table $tableNameLowercase(a ${check.dataType}) using parquet") - check.rows.map(row => - // 
scalastyle:off caselocale - sql( - s""" - |insert into $tableNameLowercase(a) - |values(${UTF8String.fromString(row).toLowerCase}) - |""".stripMargin) - // scalastyle:on caselocale - ) - checkResults(tableNameLowercase) - } - ) + } + // struct + withTable(table) { + sql(s"create table $table (s struct<fld:string collate utf8_binary_lcase>) using parquet") + sql(s"insert into $table values (named_struct('fld', 'aaa')), (named_struct('fld', 'AAA'))") + checkAnswer(sql(s"select s.fld from $table group by s"), Seq(Row("aaa"))) + } } - test("Join on arrays of collated strings") { - abstract class JoinCheck( - val leftRows: Seq[String], - val rightRows: Seq[String], - val resultRows: Seq[Any]) - extends ArrayCheck - - case class JoinSimpleCheck( - override val leftRows: Seq[String], - override val rightRows: Seq[String], - override val resultRows: Seq[Seq[String]]) - extends JoinCheck(leftRows, rightRows, resultRows) with ArrayCheckSimple - - case class JoinNestedCheck( - override val leftRows: Seq[String], - override val rightRows: Seq[String], - override val resultRows: Seq[Seq[Seq[String]]]) - extends JoinCheck(leftRows, rightRows, resultRows) with ArrayCheckNested - - val tablePrefix = "test_join_arr_collated" - val tableLeft = s"${tablePrefix}_left" - val tableLeftLowercase = s"${tableLeft}_lowercase" - val tableRight = s"${tablePrefix}_right" - val tableRightLowercase = s"${tableRight}_lowercase" - - Seq( - // simple - JoinSimpleCheck( - leftRows = Seq("array('aaa')"), - rightRows = Seq("array('AAA')"), - resultRows = Seq(Seq("aaa")) - ), - JoinSimpleCheck( - leftRows = Seq("array('aaa', 'bbb')"), - rightRows = Seq("array('AAA', 'BBB')"), - resultRows = Seq(Seq("aaa", "bbb")) - ), - JoinSimpleCheck( - leftRows = Seq("array('aaa')", "array('bbb')"), - rightRows = Seq("array('AAA')", "array('BBB')"), - resultRows = Seq(Seq("aaa"), Seq("bbb")) - ), - JoinSimpleCheck( - leftRows = Seq("array('aaa')", "array('bbb')"), - rightRows = Seq("array('AAAA')", "array('BBBB')"), - resultRows = Seq() - ), - // nested 
- JoinNestedCheck( - leftRows = Seq("array(array('aaa'))"), - rightRows = Seq("array(array('AAA'))"), - resultRows = Seq(Seq(Seq("aaa"))) - ), - JoinNestedCheck( - leftRows = Seq("array(array('aaa', 'bbb'))"), - rightRows = Seq("array(array('AAA', 'BBB'))"), - resultRows = Seq(Seq(Seq("aaa", "bbb"))) - ), - JoinNestedCheck( - Seq("array(array('aaa'), array('bbb'))"), - Seq("array(array('AAA'), array('BBB'))"), - Seq(Seq(Seq("aaa"), Seq("bbb"))) - ), - JoinNestedCheck( - Seq("array(array('aaa'), array('bbb'))"), - Seq("array(array('AAA'), array('CCC'))"), - Seq() - ) - ).map((check: JoinCheck) => - withTable(tableLeft, tableLeftLowercase, tableRight, tableRightLowercase) { - def checkResults(left: String, right: String): Unit = { - checkAnswer( - sql(s"select $left.a from $left join $right where $left.a = $right.a"), - check.resultRows.map(Row(_)) - ) - } - - // check against non-binary collation - Seq(tableLeft, tableRight).map(tab => - sql(s"create table $tab(a ${check.dataTypeCollated}) using parquet")) - Seq((tableLeft, check.leftRows), (tableRight, check.rightRows)).foreach { - case (tab, data) => data.map(row => sql(s"insert into $tab(a) values($row)")) - } - checkResults(tableLeft, tableRight) - - // binary collation with values converted to lowercase should match the results as well - Seq(tableLeftLowercase, tableRightLowercase).map(tab => - sql(s"create table $tab(a ${check.dataType}) using parquet")) - Seq((tableLeftLowercase, check.leftRows), (tableRightLowercase, check.rightRows)).foreach { - case (tab, data) => - // scalastyle:off caselocale - data.map(row => - sql(s"insert into $tab(a) values(${UTF8String.fromString(row).toLowerCase})")) - // scalastyle:on caselocale - } - checkResults(tableLeft, tableRight) + test("Joins on complex types containing collated strings") { + val tableLeft = "table_join_le" + val tableRight = "table_join_ri" + // array + withTable(tableLeft, tableRight) { + Seq(tableLeft, tableRight).map(tab => + sql(s"create table $tab (a 
array<string collate utf8_binary_lcase>) using parquet")) + Seq((tableLeft, "array('aaa')"), (tableRight, "array('AAA')")).map{ + case (tab, data) => sql(s"insert into $tab values ($data)") } - ) + checkAnswer(sql( + s""" + |select $tableLeft.a from $tableLeft + |join $tableRight on $tableLeft.a = $tableRight.a + |""".stripMargin), Seq(Row(Seq("aaa")))) + } + // map doesn't support joins + withTable(tableLeft, tableRight) { + Seq(tableLeft, tableRight).map(tab => + sql(s"create table $tab (m map<string collate utf8_binary_lcase, string>) using parquet")) + val query = + s"select $tableLeft.m from $tableLeft join $tableRight on $tableLeft.m = $tableRight.m" + val ctx = s"$tableLeft.m = $tableRight.m" + checkError( + exception = intercept[AnalysisException](sql(query)), + errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE", + parameters = Map( + "functionName" -> "`=`", + "dataType" -> toSQLType(MapType( + StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")), + StringType + )), + "sqlExpr" -> "\"(m = m)\""), + context = ExpectedContext(ctx, query.length - ctx.length, query.length - 1)) + } + // struct + withTable(tableLeft, tableRight) { + Seq(tableLeft, tableRight).map(tab => + sql(s"create table $tab (s struct<fld:string collate utf8_binary_lcase>) using parquet")) + Seq( + (tableLeft, "named_struct('fld', 'aaa')"), + (tableRight, "named_struct('fld', 'AAA')") + ).map { + case (tab, data) => sql(s"insert into $tab values ($data)") + } + checkAnswer(sql( + s""" + |select $tableLeft.s.fld from $tableLeft + |join $tableRight on $tableLeft.s = $tableRight.s + |""".stripMargin), Seq(Row("aaa"))) + } } test("window aggregates should respect collation") {