diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
index 8edb8ba51d828..44d7467e95ee8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceUtils.scala
@@ -28,7 +28,7 @@ import org.json4s.jackson.Serialization
 import org.apache.spark.{SparkException, SparkUpgradeException}
 import org.apache.spark.sql.{SPARK_LEGACY_DATETIME_METADATA_KEY, SPARK_LEGACY_INT96_METADATA_KEY, SPARK_TIMEZONE_METADATA_KEY, SPARK_VERSION_METADATA_KEY}
 import org.apache.spark.sql.catalyst.catalog.{CatalogTable, CatalogUtils}
-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, PredicateHelper}
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet, Expression, ExpressionSet, IsNotNull, IsNull, PredicateHelper}
 import org.apache.spark.sql.catalyst.util.RebaseDateTime
 import org.apache.spark.sql.catalyst.util.RebaseDateTime.RebaseSpec
 import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
@@ -287,18 +287,17 @@ object DataSourceUtils extends PredicateHelper {
    * @return A boolean indicating whether the filter should be pushed down or not.
    */
   def shouldPushFilter(expression: Expression): Boolean = {
-    def shouldPushFilterRecursive(expression: Expression): Boolean = expression match {
-      case attr: AttributeReference =>
-        attr.dataType match {
+    def checkRecursive(expression: Expression): Boolean = expression match {
+      case _: IsNull | _: IsNotNull => true
+      case _ =>
+        expression.dataType match {
           // don't push down filters for string columns with non-default collation
           // as it could lead to incorrect results
           case st: StringType => st.isDefaultCollation
-          case _ => true
+          case _ => expression.children.forall(checkRecursive)
         }
-
-      case _ => expression.children.forall(shouldPushFilterRecursive)
     }

-    expression.deterministic && shouldPushFilterRecursive(expression)
+    expression.deterministic && checkRecursive(expression)
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index c85b0d9adae35..ed13b3034cd7a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -2210,10 +2210,16 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared
   }

   test("disable filter pushdown for collated strings") {
+    def containsFilters(df: DataFrame, filterString: String): Unit = {
+      val explain = df.queryExecution.explainString(ExplainMode.fromString("extended"))
+      assert(explain.contains(filterString))
+    }
+
     withTempPath { path =>
       val collation = "'SR_CI_AI'"
       val df = sql(
-        s""" SELECT collate(c, $collation) as c
+        s""" SELECT collate(c, $collation) as c1,
+           |named_struct('f1', named_struct('f2', collate(c, $collation), 'f3', 1)) as ns
            |FROM VALUES ('aaa'), ('AAA'), ('bbb')
            |as data(c)
            |""".stripMargin)
@@ -2231,11 +2237,24 @@ abstract class ParquetFilterSuite extends QueryTest with ParquetTest with Shared

       filters.foreach { filter =>
         val readback = spark.read.parquet(path.getAbsolutePath)
-          .where(s"c ${filter._1} collate('aaa', $collation)")
-        val explain = readback.queryExecution.explainString(ExplainMode.fromString("extended"))
-        assert(explain.contains("PushedFilters: []"))
+          .where(s"c1 ${filter._1} collate('aaa', $collation)")
+          .where(s"ns.f1.f2 ${filter._1} collate('aaa', $collation)")
+          .where(s"ns ${filter._1} " +
+            s"named_struct('f1', named_struct('f2', collate('aaa', $collation), 'f3', 1))")
+          .select("c1")
+
+        containsFilters(readback,
+          "PushedFilters: [IsNotNull(c1), IsNotNull(ns.f1.f2), IsNotNull(ns)]")
         checkAnswer(readback, filter._2)
       }
+
+      // should still push down the filter for the nested column which is not collated
+      val readback = spark.read.parquet(path.getAbsolutePath)
+        .where(s"ns.f1.f3 == 1")
+        .select("c1")
+
+      containsFilters(readback, "PushedFilters: [IsNotNull(ns.f1.f3), EqualTo(ns.f1.f3,1)]")
+      checkAnswer(readback, Seq(Row("aaa"), Row("AAA"), Row("bbb")))
     }
   }
 }