Skip to content

Commit

Permalink
[SPARK-47692][SQL] Fix default StringType meaning in implicit casting
Browse files Browse the repository at this point in the history
### What changes were proposed in this pull request?
Addition of priority flag to StringType.

### Why are the changes needed?
In order to follow casting rules for collations, we need to know whether StringType is considered default, implicit or explicit.

### Does this PR introduce _any_ user-facing change?
Yes.

### How was this patch tested?
Implicit tests in CollationSuite.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes apache#45819 from mihailom-db/SPARK-47692.

Authored-by: Mihailo Milosevic <mihailo.milosevic@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
  • Loading branch information
mihailom-db authored and JacobZheng0927 committed May 11, 2024
1 parent 4a74e5f commit 17d0d7f
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@

package org.apache.spark.sql.internal.types

import org.apache.spark.sql.internal.SqlApiConf
import org.apache.spark.sql.types.{AbstractDataType, DataType, StringType}

/**
* StringTypeCollated is an abstract class for StringType with collation support.
*/
abstract class AbstractStringType extends AbstractDataType {
override private[sql] def defaultConcreteType: DataType = StringType
override private[sql] def defaultConcreteType: DataType = SqlApiConf.get.defaultStringType
override private[sql] def simpleString: String = "string"
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import javax.annotation.Nullable
import scala.annotation.tailrec

import org.apache.spark.sql.catalyst.analysis.TypeCoercion.{hasStringType, haveSameType}
import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least, Overlay, StringLPad, StringRPad}
import org.apache.spark.sql.catalyst.expressions.{ArrayJoin, BinaryExpression, CaseWhen, Cast, Coalesce, Collate, Concat, ConcatWs, CreateArray, Elt, Expression, Greatest, If, In, InSubquery, Least, Literal, Overlay, StringLPad, StringRPad}
import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
Expand All @@ -48,9 +48,9 @@ object CollationTypeCasts extends TypeCoercionRule {
case eltExpr: Elt =>
eltExpr.withNewChildren(eltExpr.children.head +: collateToSingleType(eltExpr.children.tail))

case overlay: Overlay =>
overlay.withNewChildren(collateToSingleType(Seq(overlay.input, overlay.replace))
++ Seq(overlay.pos, overlay.len))
case overlayExpr: Overlay =>
overlayExpr.withNewChildren(collateToSingleType(Seq(overlayExpr.input, overlayExpr.replace))
++ Seq(overlayExpr.pos, overlayExpr.len))

case stringPadExpr @ (_: StringRPad | _: StringLPad) =>
val Seq(str, len, pad) = stringPadExpr.children
Expand Down Expand Up @@ -108,7 +108,12 @@ object CollationTypeCasts extends TypeCoercionRule {
* complex DataTypes with collated StringTypes (e.g. ArrayType)
*/
def getOutputCollation(expr: Seq[Expression]): StringType = {
val explicitTypes = expr.filter(_.isInstanceOf[Collate])
val explicitTypes = expr.filter {
case _: Collate => true
case cast: Cast if cast.getTagValue(Cast.USER_SPECIFIED_CAST).isDefined =>
cast.dataType.isInstanceOf[StringType]
case _ => false
}
.map(_.dataType.asInstanceOf[StringType].collationId)
.distinct

Expand All @@ -123,17 +128,22 @@ object CollationTypeCasts extends TypeCoercionRule {
)
// Only implicit or default collations present
case 0 =>
val implicitTypes = expr.map(_.dataType)
val implicitTypes = expr.filter {
case Literal(_, _: StringType) => false
case cast: Cast if cast.getTagValue(Cast.USER_SPECIFIED_CAST).isEmpty =>
cast.child.dataType.isInstanceOf[StringType]
case _ => true
}
.map(_.dataType)
.filter(hasStringType)
.map(extractStringType)
.filter(dt => dt.collationId != SQLConf.get.defaultStringType.collationId)
.distinctBy(_.collationId)
.map(extractStringType(_).collationId)
.distinct

if (implicitTypes.length > 1) {
throw QueryCompilationErrors.implicitCollationMismatchError()
}
else {
implicitTypes.headOption.getOrElse(SQLConf.get.defaultStringType)
implicitTypes.headOption.map(StringType(_)).getOrElse(SQLConf.get.defaultStringType)
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -998,9 +998,10 @@ object TypeCoercion extends TypeCoercionBase {
case (_: StringType, AnyTimestampType) => AnyTimestampType.defaultConcreteType
case (_: StringType, BinaryType) => BinaryType
// Cast any atomic type to string.
case (any: AtomicType, _: StringType) if !any.isInstanceOf[StringType] => StringType
case (any: AtomicType, st: StringType) if !any.isInstanceOf[StringType] => st
case (any: AtomicType, st: AbstractStringType)
if !any.isInstanceOf[StringType] => st.defaultConcreteType
if !any.isInstanceOf[StringType] =>
st.defaultConcreteType

// When we reach here, input type is not acceptable for any types in this type collection,
// try to find the first one we can implicitly cast.
Expand Down
66 changes: 63 additions & 3 deletions sql/core/src/test/scala/org/apache/spark/sql/CollationSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import scala.jdk.CollectionConverters.MapHasAsJava

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.ExtendedAnalysisException
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.util.CollationFactory
import org.apache.spark.sql.connector.{DatasourceV2SQLBase, FakeV2ProviderWithCustomSchema}
import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTable}
Expand Down Expand Up @@ -412,7 +413,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
}
}

test("implicit casting of collated strings") {
test("SPARK-47210: Implicit casting of collated strings") {
val tableName = "parquet_dummy_implicit_cast_t22"
withTable(tableName) {
spark.sql(
Expand Down Expand Up @@ -566,7 +567,66 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
}
}

test("cast of default collated strings in IN expression") {
test("SPARK-47692: Parameter marker with EXECUTE IMMEDIATE implicit casting") {
sql(s"DECLARE stmtStr1 = 'SELECT collation(:var1 || :var2)';")
sql(s"DECLARE stmtStr2 = 'SELECT collation(:var1 || (\\\'a\\\' COLLATE UNICODE))';")

checkAnswer(
sql(
"""EXECUTE IMMEDIATE stmtStr1 USING
| 'a' AS var1,
| 'b' AS var2;""".stripMargin),
Seq(Row("UTF8_BINARY"))
)

withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") {
checkAnswer(
sql(
"""EXECUTE IMMEDIATE stmtStr1 USING
| 'a' AS var1,
| 'b' AS var2;""".stripMargin),
Seq(Row("UNICODE"))
)
}

checkAnswer(
sql(
"""EXECUTE IMMEDIATE stmtStr2 USING
| 'a' AS var1;""".stripMargin),
Seq(Row("UNICODE"))
)

withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") {
checkAnswer(
sql(
"""EXECUTE IMMEDIATE stmtStr2 USING
| 'a' AS var1;""".stripMargin),
Seq(Row("UNICODE"))
)
}
}

test("SPARK-47692: Parameter markers with variable mapping") {
checkAnswer(
spark.sql(
"SELECT collation(:var1 || :var2)",
Map("var1" -> Literal.create('a', StringType("UTF8_BINARY")),
"var2" -> Literal.create('b', StringType("UNICODE")))),
Seq(Row("UTF8_BINARY"))
)

withSQLConf(SqlApiConf.DEFAULT_COLLATION -> "UNICODE") {
checkAnswer(
spark.sql(
"SELECT collation(:var1 || :var2)",
Map("var1" -> Literal.create('a', StringType("UTF8_BINARY")),
"var2" -> Literal.create('b', StringType("UNICODE")))),
Seq(Row("UNICODE"))
)
}
}

test("SPARK-47210: Cast of default collated strings in IN expression") {
val tableName = "t1"
withTable(tableName) {
spark.sql(
Expand All @@ -591,7 +651,7 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper {
}

// TODO(SPARK-47210): Add indeterminate support
test("indeterminate collation checks") {
test("SPARK-47210: Indeterminate collation checks") {
val tableName = "t1"
val newTableName = "t2"
withTable(tableName) {
Expand Down

0 comments on commit 17d0d7f

Please sign in to comment.