Skip to content

Commit

Permalink
Refactor tests
Browse files Browse the repository at this point in the history
  • Loading branch information
nikolamand-db committed Mar 21, 2024
1 parent 5f59648 commit 514d2cd
Show file tree
Hide file tree
Showing 2 changed files with 147 additions and 192 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.math.{BigDecimal => JavaBigDecimal}

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.expressions.{SpecificInternalRow, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{Decimal, DecimalType, IntegerType, StringType, StructField, StructType}
import org.apache.spark.sql.types.{ArrayType, Decimal, DecimalType, IntegerType, MapType, StringType, StructField, StructType}

class UnsafeRowUtilsSuite extends SparkFunSuite {

Expand Down Expand Up @@ -91,4 +91,70 @@ class UnsafeRowUtilsSuite extends SparkFunSuite {
"fieldStatus:\n" +
"[UnsafeRowFieldStatus] index: 0, expectedFieldType: IntegerType,"))
}

// Checks UnsafeRowUtils.isBinaryStable on atomic types and on arrays, maps and structs,
// including nested combinations. Every assertion encodes the same expectation: a type is
// binary stable only if no string anywhere inside it uses the UTF8_BINARY_LCASE collation.
test("isBinaryStable on complex types containing collated strings") {
  // String type whose collation the test expects to be reported as NOT binary stable.
  val nonBinaryStringType = StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE"))

  // simple checks
  assert(UnsafeRowUtils.isBinaryStable(IntegerType))
  assert(UnsafeRowUtils.isBinaryStable(StringType))
  assert(!UnsafeRowUtils.isBinaryStable(nonBinaryStringType))

  // arrays: stability is expected to follow the element type
  assert(UnsafeRowUtils.isBinaryStable(ArrayType(IntegerType)))
  assert(UnsafeRowUtils.isBinaryStable(ArrayType(StringType)))
  assert(!UnsafeRowUtils.isBinaryStable(ArrayType(nonBinaryStringType)))

  // maps: a collated string in either the key or the value position must be detected
  assert(UnsafeRowUtils.isBinaryStable(MapType(StringType, StringType)))
  assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, StringType)))
  assert(!UnsafeRowUtils.isBinaryStable(MapType(StringType, nonBinaryStringType)))
  assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, nonBinaryStringType)))
  assert(!UnsafeRowUtils.isBinaryStable(MapType(nonBinaryStringType, IntegerType)))
  assert(!UnsafeRowUtils.isBinaryStable(MapType(IntegerType, nonBinaryStringType)))

  // structs: stability is expected to follow the field types
  assert(UnsafeRowUtils.isBinaryStable(StructType(StructField("field", IntegerType) :: Nil)))
  assert(UnsafeRowUtils.isBinaryStable(StructType(StructField("field", StringType) :: Nil)))
  assert(!UnsafeRowUtils.isBinaryStable(
    StructType(StructField("field", nonBinaryStringType) :: Nil)))

  // nested complex types: array as the outer type
  assert(UnsafeRowUtils.isBinaryStable(ArrayType(ArrayType(StringType))))
  assert(UnsafeRowUtils.isBinaryStable(ArrayType(MapType(StringType, IntegerType))))
  assert(UnsafeRowUtils.isBinaryStable(
    ArrayType(StructType(StructField("field", StringType) :: Nil))))
  assert(!UnsafeRowUtils.isBinaryStable(ArrayType(ArrayType(nonBinaryStringType))))
  assert(!UnsafeRowUtils.isBinaryStable(ArrayType(MapType(IntegerType, nonBinaryStringType))))
  assert(!UnsafeRowUtils.isBinaryStable(
    ArrayType(MapType(IntegerType, ArrayType(nonBinaryStringType)))))
  assert(!UnsafeRowUtils.isBinaryStable(
    ArrayType(StructType(StructField("field", nonBinaryStringType) :: Nil))))
  // Only the second field is collated, so the check must look past a stable first field.
  assert(!UnsafeRowUtils.isBinaryStable(ArrayType(StructType(
    Seq(StructField("first", IntegerType), StructField("second", nonBinaryStringType))))))

  // nested complex types: map as the outer type
  assert(UnsafeRowUtils.isBinaryStable(MapType(ArrayType(StringType), ArrayType(IntegerType))))
  assert(UnsafeRowUtils.isBinaryStable(MapType(MapType(StringType, StringType), IntegerType)))
  assert(UnsafeRowUtils.isBinaryStable(
    MapType(StructType(StructField("field", StringType) :: Nil), IntegerType)))
  assert(!UnsafeRowUtils.isBinaryStable(
    MapType(ArrayType(nonBinaryStringType), ArrayType(IntegerType))))
  assert(!UnsafeRowUtils.isBinaryStable(
    MapType(IntegerType, ArrayType(nonBinaryStringType))))
  assert(!UnsafeRowUtils.isBinaryStable(
    MapType(MapType(IntegerType, nonBinaryStringType), IntegerType)))
  assert(!UnsafeRowUtils.isBinaryStable(
    MapType(StructType(StructField("field", nonBinaryStringType) :: Nil), IntegerType)))

  // nested complex types: struct as the outer type
  assert(UnsafeRowUtils.isBinaryStable(
    StructType(StructField("field", ArrayType(IntegerType)) :: Nil)))
  assert(UnsafeRowUtils.isBinaryStable(
    StructType(StructField("field", MapType(StringType, IntegerType)) :: Nil)))
  assert(UnsafeRowUtils.isBinaryStable(
    StructType(StructField("field", StructType(StructField("sub", IntegerType) :: Nil)) :: Nil)))
  assert(!UnsafeRowUtils.isBinaryStable(
    StructType(StructField("field", ArrayType(nonBinaryStringType)) :: Nil)))
  assert(!UnsafeRowUtils.isBinaryStable(
    StructType(StructField("field", MapType(nonBinaryStringType, IntegerType)) :: Nil)))
  assert(!UnsafeRowUtils.isBinaryStable(
    StructType(StructField("field",
      StructType(StructField("sub", nonBinaryStringType) :: Nil)) :: Nil)))
}
}
271 changes: 80 additions & 191 deletions sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCu
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.CatalogHelper
import org.apache.spark.sql.connector.catalog.CatalogV2Util.withDefaultOwnership
import org.apache.spark.sql.errors.DataTypeErrors.toSQLType
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, ObjectHashAggregateExec}
import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}

class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
protected val v2Source = classOf[FakeV2ProviderWithCustomSchema].getName
Expand Down Expand Up @@ -641,199 +641,88 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
"reason" -> "generation expression cannot contain non-default collated string type"))
}

/**
 * Describes the column type used by a collation check as two SQL type strings: one without
 * an explicit collation and one with it. Both strings are spliced into `create table`
 * statements by the tests that use this trait.
 */
trait ArrayCheck {
// SQL type of the column with no collation specified on its strings.
def dataType: String
// SQL type of the same column with an explicit collation on its strings.
def dataTypeCollated: String
}

/**
 * Column types for a single-level array of strings; the collated variant uses the
 * `utf8_binary_lcase` collation.
 */
trait ArrayCheckSimple extends ArrayCheck {
override def dataType: String = "array<string>"
override def dataTypeCollated: String = "array<string collate utf8_binary_lcase>"
}

/**
 * Column types for an array of arrays of strings; the collated variant applies the
 * `utf8_binary_lcase` collation to the innermost strings.
 */
trait ArrayCheckNested extends ArrayCheck {
override def dataType: String = "array<array<string>>"
override def dataTypeCollated: String = "array<array<string collate utf8_binary_lcase>>"
}

test("Aggregation of arrays built on collated strings") {
abstract class AggCheck(val rows: Seq[String], val result: Seq[(Any, Int)]) extends ArrayCheck

case class AggCheckSimple(
override val rows: Seq[String],
override val result: Seq[(Seq[String], Int)])
extends AggCheck(rows, result) with ArrayCheckSimple

case class AggCheckNested(
override val rows: Seq[String],
override val result: Seq[(Seq[Seq[String]], Int)])
extends AggCheck(rows, result) with ArrayCheckNested

val tableName = "test_agg_arr_collated"
val tableNameLowercase = "test_agg_arr_collated_lowercase"

Seq(
// simple
AggCheckSimple(
rows = Seq("array('aaa')", "array('AAA')"),
result = Seq((Seq("aaa"), 2))
),
AggCheckSimple(
rows = Seq("array('aaa', 'bbb')", "array('AAA', 'BBB')"),
result = Seq((Seq("aaa", "bbb"), 2))
),
AggCheckSimple(
rows = Seq("array('aaa')", "array('bbb')", "array('AAA')", "array('BBB')"),
result = Seq((Seq("aaa"), 2), (Seq("bbb"), 2))
),
// nested
AggCheckNested(
rows = Seq("array(array('aaa'))", "array(array('AAA'))"),
result = Seq((Seq(Seq("aaa")), 2))
),
AggCheckNested(
rows = Seq("array(array('aaa'), array('bbb'))", "array(array('AAA'), array('bbb'))"),
result = Seq((Seq(Seq("aaa"), Seq("bbb")), 2))
),
AggCheckNested(
rows = Seq(
"array(array('aaa', 'aaa'), array('bbb', 'ccc'))",
"array(array('aaa', 'aaa'), array('bbb', 'ccc'), array('ddd'))",
"array(array('AAA', 'AAA'), array('BBB', 'CCC'))"
),
result = Seq(
(Seq(Seq("aaa", "aaa"), Seq("bbb", "ccc")), 2),
(Seq(Seq("aaa", "aaa"), Seq("bbb", "ccc"), Seq("ddd")), 1)
)
test("Aggregation on complex containing collated strings") {
val table = "table_agg"
// array
withTable(table) {
sql(s"create table $table (a array<string collate utf8_binary_lcase>) using parquet")
sql(s"insert into $table values (array('aaa')), (array('AAA'))")
checkAnswer(sql(s"select distinct a from $table"), Seq(Row(Seq("aaa"))))
}
// map doesn't support aggregation
withTable(table) {
sql(s"create table $table (m map<string collate utf8_binary_lcase, string>) using parquet")
val query = s"select distinct m from $table"
checkError(
exception = intercept[ExtendedAnalysisException](sql(query)),
errorClass = "UNSUPPORTED_FEATURE.SET_OPERATION_ON_MAP_TYPE",
parameters = Map(
"colName" -> "`m`",
"dataType" -> toSQLType(MapType(
StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")),
StringType))),
context = ExpectedContext(query, 0, query.length - 1)
)
).map((check: AggCheck) =>
withTable(tableName, tableNameLowercase) {
def checkResults(table: String): Unit = {
checkAnswer(sql(s"select a, count(*) from $table group by a"),
check.result.map{ case (agg, cnt) => Row(agg, cnt) })
checkAnswer(sql(s"select distinct a from $table"),
check.result.map{ case (agg, _) => Row(agg) })
}

// check against non-binary collation
sql(s"create table $tableName(a ${check.dataTypeCollated}) using parquet")
check.rows.map(row => sql(s"insert into $tableName(a) values($row)"))
checkResults(tableName)

// binary collation with values converted to lowercase should match the results as well
sql(s"create table $tableNameLowercase(a ${check.dataType}) using parquet")
check.rows.map(row =>
// scalastyle:off caselocale
sql(
s"""
|insert into $tableNameLowercase(a)
|values(${UTF8String.fromString(row).toLowerCase})
|""".stripMargin)
// scalastyle:on caselocale
)
checkResults(tableNameLowercase)
}
)
}
// struct
withTable(table) {
sql(s"create table $table (s struct<fld:string collate utf8_binary_lcase>) using parquet")
sql(s"insert into $table values (named_struct('fld', 'aaa')), (named_struct('fld', 'AAA'))")
checkAnswer(sql(s"select s.fld from $table group by s"), Seq(Row("aaa")))
}
}

test("Join on arrays of collated strings") {
abstract class JoinCheck(
val leftRows: Seq[String],
val rightRows: Seq[String],
val resultRows: Seq[Any])
extends ArrayCheck

case class JoinSimpleCheck(
override val leftRows: Seq[String],
override val rightRows: Seq[String],
override val resultRows: Seq[Seq[String]])
extends JoinCheck(leftRows, rightRows, resultRows) with ArrayCheckSimple

case class JoinNestedCheck(
override val leftRows: Seq[String],
override val rightRows: Seq[String],
override val resultRows: Seq[Seq[Seq[String]]])
extends JoinCheck(leftRows, rightRows, resultRows) with ArrayCheckNested

val tablePrefix = "test_join_arr_collated"
val tableLeft = s"${tablePrefix}_left"
val tableLeftLowercase = s"${tableLeft}_lowercase"
val tableRight = s"${tablePrefix}_right"
val tableRightLowercase = s"${tableRight}_lowercase"

Seq(
// simple
JoinSimpleCheck(
leftRows = Seq("array('aaa')"),
rightRows = Seq("array('AAA')"),
resultRows = Seq(Seq("aaa"))
),
JoinSimpleCheck(
leftRows = Seq("array('aaa', 'bbb')"),
rightRows = Seq("array('AAA', 'BBB')"),
resultRows = Seq(Seq("aaa", "bbb"))
),
JoinSimpleCheck(
leftRows = Seq("array('aaa')", "array('bbb')"),
rightRows = Seq("array('AAA')", "array('BBB')"),
resultRows = Seq(Seq("aaa"), Seq("bbb"))
),
JoinSimpleCheck(
leftRows = Seq("array('aaa')", "array('bbb')"),
rightRows = Seq("array('AAAA')", "array('BBBB')"),
resultRows = Seq()
),
// nested
JoinNestedCheck(
leftRows = Seq("array(array('aaa'))"),
rightRows = Seq("array(array('AAA'))"),
resultRows = Seq(Seq(Seq("aaa")))
),
JoinNestedCheck(
leftRows = Seq("array(array('aaa', 'bbb'))"),
rightRows = Seq("array(array('AAA', 'BBB'))"),
resultRows = Seq(Seq(Seq("aaa", "bbb")))
),
JoinNestedCheck(
Seq("array(array('aaa'), array('bbb'))"),
Seq("array(array('AAA'), array('BBB'))"),
Seq(Seq(Seq("aaa"), Seq("bbb")))
),
JoinNestedCheck(
Seq("array(array('aaa'), array('bbb'))"),
Seq("array(array('AAA'), array('CCC'))"),
Seq()
)
).map((check: JoinCheck) =>
withTable(tableLeft, tableLeftLowercase, tableRight, tableRightLowercase) {
def checkResults(left: String, right: String): Unit = {
checkAnswer(
sql(s"select $left.a from $left join $right where $left.a = $right.a"),
check.resultRows.map(Row(_))
)
}

// check against non-binary collation
Seq(tableLeft, tableRight).map(tab =>
sql(s"create table $tab(a ${check.dataTypeCollated}) using parquet"))
Seq((tableLeft, check.leftRows), (tableRight, check.rightRows)).foreach {
case (tab, data) => data.map(row => sql(s"insert into $tab(a) values($row)"))
}
checkResults(tableLeft, tableRight)

// binary collation with values converted to lowercase should match the results as well
Seq(tableLeftLowercase, tableRightLowercase).map(tab =>
sql(s"create table $tab(a ${check.dataType}) using parquet"))
Seq((tableLeftLowercase, check.leftRows), (tableRightLowercase, check.rightRows)).foreach {
case (tab, data) =>
// scalastyle:off caselocale
data.map(row =>
sql(s"insert into $tab(a) values(${UTF8String.fromString(row).toLowerCase})"))
// scalastyle:on caselocale
}
checkResults(tableLeft, tableRight)
test("Joins on complex types containing collated strings") {
val tableLeft = "table_join_le"
val tableRight = "table_join_ri"
// array
withTable(tableLeft, tableRight) {
Seq(tableLeft, tableRight).map(tab =>
sql(s"create table $tab (a array<string collate utf8_binary_lcase>) using parquet"))
Seq((tableLeft, "array('aaa')"), (tableRight, "array('AAA')")).map{
case (tab, data) => sql(s"insert into $tab values ($data)")
}
)
checkAnswer(sql(
s"""
|select $tableLeft.a from $tableLeft
|join $tableRight on $tableLeft.a = $tableRight.a
|""".stripMargin), Seq(Row(Seq("aaa"))))
}
// map doesn't support joins
withTable(tableLeft, tableRight) {
Seq(tableLeft, tableRight).map(tab =>
sql(s"create table $tab (m map<string collate utf8_binary_lcase, string>) using parquet"))
val query =
s"select $tableLeft.m from $tableLeft join $tableRight on $tableLeft.m = $tableRight.m"
val ctx = s"$tableLeft.m = $tableRight.m"
checkError(
exception = intercept[AnalysisException](sql(query)),
errorClass = "DATATYPE_MISMATCH.INVALID_ORDERING_TYPE",
parameters = Map(
"functionName" -> "`=`",
"dataType" -> toSQLType(MapType(
StringType(CollationFactory.collationNameToId("UTF8_BINARY_LCASE")),
StringType
)),
"sqlExpr" -> "\"(m = m)\""),
context = ExpectedContext(ctx, query.length - ctx.length, query.length - 1))
}
// struct
withTable(tableLeft, tableRight) {
Seq(tableLeft, tableRight).map(tab =>
sql(s"create table $tab (s struct<fld:string collate utf8_binary_lcase>) using parquet"))
Seq(
(tableLeft, "named_struct('fld', 'aaa')"),
(tableRight, "named_struct('fld', 'AAA')")
).map {
case (tab, data) => sql(s"insert into $tab values ($data)")
}
checkAnswer(sql(
s"""
|select $tableLeft.s.fld from $tableLeft
|join $tableRight on $tableLeft.s = $tableRight.s
|""".stripMargin), Seq(Row("aaa")))
}
}

test("window aggregates should respect collation") {
Expand Down

0 comments on commit 514d2cd

Please sign in to comment.