From 490a4b3b1fdf47991b5a6588df14e63c3dd8b211 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 5 Jun 2024 13:01:11 -0700 Subject: [PATCH] [SPARK-48498][SQL] Always do char padding in predicates ### What changes were proposed in this pull request? For some data sources, CHAR type padding is applied on neither the write side nor the read side (by disabling `spark.sql.readSideCharPadding`). This is a different SQL flavor, similar to MySQL: https://dev.mysql.com/doc/refman/8.0/en/char.html However, there is a bug in Spark: it always pads the string literal when comparing a CHAR type column with a STRING literal, which assumes that CHAR type columns are always padded, either on the write side or on the read side. This is not always true. This PR makes Spark always pad the CHAR type columns when comparing them with string literals, to satisfy the CHAR type semantics. ### Why are the changes needed? This is a bug fix for the case where read-side char padding is disabled. ### Does this PR introduce _any_ user-facing change? Yes. After this PR, comparing a CHAR type column with STRING literals follows the CHAR semantics, whereas before it mostly returned false. ### How was this patch tested? New tests. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #46832 from cloud-fan/char. 
Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../apache/spark/sql/internal/SQLConf.scala | 8 ++++ .../datasources/ApplyCharTypePadding.scala | 39 ++++++++++++++----- .../spark/sql/CharVarcharTestSuite.scala | 28 +++++++++++++ .../apache/spark/sql/PlanStabilitySuite.scala | 8 ++-- 4 files changed, 70 insertions(+), 13 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c4e584b9e31db..f4751f2027894 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -4616,6 +4616,14 @@ object SQLConf { .booleanConf .createWithDefault(true) + val LEGACY_NO_CHAR_PADDING_IN_PREDICATE = buildConf("spark.sql.legacy.noCharPaddingInPredicate") + .internal() + .doc("When true, Spark will not apply char type padding for CHAR type columns in string " + + s"comparison predicates, when '${READ_SIDE_CHAR_PADDING.key}' is false.") + .version("4.0.0") + .booleanConf + .createWithDefault(false) + val CLI_PRINT_HEADER = buildConf("spark.sql.cli.print.header") .doc("When set to true, spark-sql CLI prints the names of the columns in query output.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala index b5bf337a5a2e6..1b7b0d702ab98 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/ApplyCharTypePadding.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_COMPARISON, IN} import org.apache.spark.sql.catalyst.util.CharVarcharUtils import 
org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{CharType, Metadata, StringType} import org.apache.spark.unsafe.types.UTF8String @@ -66,9 +67,10 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { r.copy(dataCols = cleanedDataCols, partitionCols = cleanedPartCols) }) } - paddingForStringComparison(newPlan) + paddingForStringComparison(newPlan, padCharCol = false) } else { - paddingForStringComparison(plan) + paddingForStringComparison( + plan, padCharCol = !conf.getConf(SQLConf.LEGACY_NO_CHAR_PADDING_IN_PREDICATE)) } } @@ -90,7 +92,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { } } - private def paddingForStringComparison(plan: LogicalPlan): LogicalPlan = { + private def paddingForStringComparison(plan: LogicalPlan, padCharCol: Boolean): LogicalPlan = { plan.resolveOperatorsUpWithPruning(_.containsAnyPattern(BINARY_COMPARISON, IN)) { case operator => operator.transformExpressionsUpWithPruning( _.containsAnyPattern(BINARY_COMPARISON, IN)) { @@ -99,12 +101,12 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { // String literal is treated as char type when it's compared to a char type column. // We should pad the shorter one to the longer length. 
case b @ BinaryComparison(e @ AttrOrOuterRef(attr), lit) if lit.foldable => - padAttrLitCmp(e, attr.metadata, lit).map { newChildren => + padAttrLitCmp(e, attr.metadata, padCharCol, lit).map { newChildren => b.withNewChildren(newChildren) }.getOrElse(b) case b @ BinaryComparison(lit, e @ AttrOrOuterRef(attr)) if lit.foldable => - padAttrLitCmp(e, attr.metadata, lit).map { newChildren => + padAttrLitCmp(e, attr.metadata, padCharCol, lit).map { newChildren => b.withNewChildren(newChildren.reverse) }.getOrElse(b) @@ -117,9 +119,10 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { val literalCharLengths = literalChars.map(_.numChars()) val targetLen = (length +: literalCharLengths).max Some(i.copy( - value = addPadding(e, length, targetLen), + value = addPadding(e, length, targetLen, alwaysPad = padCharCol), list = list.zip(literalCharLengths).map { - case (lit, charLength) => addPadding(lit, charLength, targetLen) + case (lit, charLength) => + addPadding(lit, charLength, targetLen, alwaysPad = false) } ++ nulls.map(Literal.create(_, StringType)))) case _ => None }.getOrElse(i) @@ -162,6 +165,7 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { private def padAttrLitCmp( expr: Expression, metadata: Metadata, + padCharCol: Boolean, lit: Expression): Option[Seq[Expression]] = { if (expr.dataType == StringType) { CharVarcharUtils.getRawType(metadata).flatMap { @@ -174,7 +178,14 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { if (length < stringLitLen) { Some(Seq(StringRPad(expr, Literal(stringLitLen)), lit)) } else if (length > stringLitLen) { - Some(Seq(expr, StringRPad(lit, Literal(length)))) + val paddedExpr = if (padCharCol) { + StringRPad(expr, Literal(length)) + } else { + expr + } + Some(Seq(paddedExpr, StringRPad(lit, Literal(length)))) + } else if (padCharCol) { + Some(Seq(StringRPad(expr, Literal(length)), lit)) } else { None } @@ -186,7 +197,15 @@ object ApplyCharTypePadding extends Rule[LogicalPlan] { } } - private def 
addPadding(expr: Expression, charLength: Int, targetLength: Int): Expression = { - if (targetLength > charLength) StringRPad(expr, Literal(targetLength)) else expr + private def addPadding( + expr: Expression, + charLength: Int, + targetLength: Int, + alwaysPad: Boolean): Expression = { + if (targetLength > charLength) { + StringRPad(expr, Literal(targetLength)) + } else if (alwaysPad) { + StringRPad(expr, Literal(charLength)) + } else expr } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala index 013177425da78..a93dee3bf2a61 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CharVarcharTestSuite.scala @@ -942,6 +942,34 @@ class FileSourceCharVarcharTestSuite extends CharVarcharTestSuite with SharedSpa } } } + + test("SPARK-48498: always do char padding in predicates") { + import testImplicits._ + withSQLConf(SQLConf.READ_SIDE_CHAR_PADDING.key -> "false") { + withTempPath { dir => + withTable("t") { + Seq( + "12" -> "12", + "12" -> "12 ", + "12 " -> "12", + "12 " -> "12 " + ).toDF("c1", "c2").write.format(format).save(dir.toString) + sql(s"CREATE TABLE t (c1 CHAR(3), c2 STRING) USING $format LOCATION '$dir'") + // Comparing CHAR column with STRING column directly compares the stored value. + checkAnswer( + sql("SELECT c1 = c2 FROM t"), + Seq(Row(true), Row(false), Row(false), Row(true)) + ) + // No matter the CHAR type value is padded or not in the storage, we should always pad it + // before comparison with STRING literals. 
+ checkAnswer( + sql("SELECT c1 = '12', c1 = '12 ', c1 = '12 ' FROM t WHERE c2 = '12'"), + Seq(Row(true, true, true), Row(true, true, true)) + ) + } + } + } + } } class DSV2CharVarcharTestSuite extends CharVarcharTestSuite diff --git a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala index 34c6c49bc4981..ad424b3a7cc76 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/PlanStabilitySuite.scala @@ -256,9 +256,11 @@ trait PlanStabilitySuite extends DisableAdaptiveExecutionSuite { protected def testQuery(tpcdsGroup: String, query: String, suffix: String = ""): Unit = { val queryString = resourceToString(s"$tpcdsGroup/$query.sql", classLoader = Thread.currentThread().getContextClassLoader) - // Disable char/varchar read-side handling for better performance. - withSQLConf(SQLConf.READ_SIDE_CHAR_PADDING.key -> "false", - SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") { + withSQLConf( + // Disable char/varchar read-side handling for better performance. + SQLConf.READ_SIDE_CHAR_PADDING.key -> "false", + SQLConf.LEGACY_NO_CHAR_PADDING_IN_PREDICATE.key -> "true", + SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "10MB") { val qe = sql(queryString).queryExecution val plan = qe.executedPlan val explain = normalizeLocation(normalizeIds(qe.explainString(FormattedMode)))