From c13932060cb19781537c28a67023f08942180b81 Mon Sep 17 00:00:00 2001
From: Alexey Kudinkin
Date: Thu, 15 Sep 2022 18:45:02 -0700
Subject: [PATCH] [HUDI-4851] Fixing CSI not handling `InSet` operator properly (#6685)

---
 .../spark/sql/hudi/DataSkippingUtils.scala    | 20 ++++--
 .../apache/hudi/TestDataSkippingUtils.scala   | 72 ++++++++++++++-----
 2 files changed, 70 insertions(+), 22 deletions(-)

diff --git a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/DataSkippingUtils.scala b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/DataSkippingUtils.scala
index 6a04ec57e1127..0fe62da0ded36 100644
--- a/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/DataSkippingUtils.scala
+++ b/hudi-spark-datasource/hudi-spark-common/src/main/scala/org/apache/spark/sql/hudi/DataSkippingUtils.scala
@@ -23,7 +23,7 @@ import org.apache.hudi.common.util.ValidationUtils.checkState
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, EqualNullSafe, EqualTo, Expression, ExtractValue, GetStructField, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, Or, StartsWith, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, EqualNullSafe, EqualTo, Expression, ExtractValue, GetStructField, GreaterThan, GreaterThanOrEqual, In, InSet, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, Or, StartsWith, SubqueryExpression}
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.hudi.ColumnStatsExpressionUtils._
 import org.apache.spark.sql.types.StructType
@@ -61,7 +61,7 @@ object DataSkippingUtils extends Logging {
     }
   }
 
-  private def tryComposeIndexFilterExpr(sourceExpr: Expression, indexSchema: StructType): Option[Expression] = {
+  private def tryComposeIndexFilterExpr(sourceFilterExpr: Expression, indexSchema: StructType): Option[Expression] = {
     //
     // For translation of the Filter Expression for the Data Table into Filter Expression for Column Stats Index, we're
     // assuming that
@@ -91,7 +91,7 @@
     //    colA_minValue = min(colA) => transform_expr(colA_minValue) = min(transform_expr(colA))
     //    colA_maxValue = max(colA) => transform_expr(colA_maxValue) = max(transform_expr(colA))
     //
-    sourceExpr match {
+    sourceFilterExpr match {
       // If Expression is not resolved, we can't perform the analysis accurately, bailing
       case expr if !expr.resolved => None
 
@@ -227,6 +227,16 @@ object DataSkippingUtils extends Logging {
           list.map(lit => genColumnValuesEqualToExpression(colName, lit, targetExprBuilder)).reduce(Or)
         }
 
+      // Filter "expr(colA) in (B1, B2, ...)"
+      // NOTE: [[InSet]] is an optimized version of the [[In]] expression, where every sub-expression w/in the
+      //       set is a static literal
+      case InSet(sourceExpr @ AllowedTransformationExpression(attrRef), hset: Set[Any]) =>
+        getTargetIndexedColumnName(attrRef, indexSchema)
+          .map { colName =>
+            val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+            hset.map(value => genColumnValuesEqualToExpression(colName, Literal(value), targetExprBuilder)).reduce(Or)
+          }
+
       // Filter "expr(colA) not in (B1, B2, ...)"
       //          Translates to "NOT((colA_minValue = B1 AND colA_maxValue = B1) OR (colA_minValue = B2 AND colA_maxValue = B2))" for index lookup
       // NOTE: This is NOT an inversion of `in (B1, B2, ...)` expr, this is equivalent to "colA != B1 AND colA != B2 AND ..."
@@ -331,8 +341,8 @@ private object ColumnStatsExpressionUtils {
   @inline def genColValueCountExpr: Expression = col(getValueCountColumnNameFor).expr
 
   @inline def genColumnValuesEqualToExpression(colName: String,
-                                               value: Expression,
-                                               targetExprBuilder: Function[Expression, Expression] = Predef.identity): Expression = {
+                                              value: Expression,
+                                              targetExprBuilder: Function[Expression, Expression] = Predef.identity): Expression = {
     val minValueExpr = targetExprBuilder.apply(genColMinValueExpr(colName))
     val maxValueExpr = targetExprBuilder.apply(genColMaxValueExpr(colName))
     // Only case when column C contains value V is when min(C) <= V <= max(c)
diff --git a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala
index 95d784b6532a2..da3fd52e97e12 100644
--- a/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala
+++ b/hudi-spark-datasource/hudi-spark/src/test/scala/org/apache/hudi/TestDataSkippingUtils.scala
@@ -20,7 +20,8 @@ package org.apache.hudi
 import org.apache.hudi.ColumnStatsIndexSupport.composeIndexSchema
 import org.apache.hudi.testutils.HoodieClientTestBase
 import org.apache.spark.sql.HoodieCatalystExpressionUtils.resolveExpr
-import org.apache.spark.sql.catalyst.expressions.{Expression, Not}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.{Expression, InSet, Not}
 import org.apache.spark.sql.functions.{col, lower}
 import org.apache.spark.sql.hudi.DataSkippingUtils
 import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
@@ -34,6 +35,7 @@ import org.junit.jupiter.params.provider.{Arguments, MethodSource}
 import java.sql.Timestamp
 
 import scala.collection.JavaConverters._
+import scala.collection.immutable.HashSet
 
 // NOTE: Only A, B columns are indexed
 case class IndexRow(fileName: String,
@@ -80,31 +82,38 @@ class TestDataSkippingUtils extends HoodieClientTestBase with SparkAdapterSuppor
   val indexSchema: StructType = composeIndexSchema(indexedCols, sourceTableSchema)
 
   @ParameterizedTest
-  @MethodSource(
-    Array(
-      "testBasicLookupFilterExpressionsSource",
-      "testAdvancedLookupFilterExpressionsSource",
-      "testCompositeFilterExpressionsSource"
-    ))
-  def testLookupFilterExpressions(sourceExpr: String, input: Seq[IndexRow], output: Seq[String]): Unit = {
+  @MethodSource(Array(
+    "testBasicLookupFilterExpressionsSource",
+    "testAdvancedLookupFilterExpressionsSource",
+    "testCompositeFilterExpressionsSource"
+  ))
+  def testLookupFilterExpressions(sourceFilterExprStr: String, input: Seq[IndexRow], expectedOutput: Seq[String]): Unit = {
     // We have to fix the timezone to make sure all date-bound utilities output
     // is consistent with the fixtures
     spark.sqlContext.setConf(SESSION_LOCAL_TIMEZONE.key, "UTC")
 
-    val resolvedExpr: Expression = resolveExpr(spark, sourceExpr, sourceTableSchema)
-    val lookupFilter = DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr(resolvedExpr, indexSchema)
+    val resolvedFilterExpr: Expression = resolveExpr(spark, sourceFilterExprStr, sourceTableSchema)
+    val rows: Seq[String] = applyFilterExpr(resolvedFilterExpr, input)
 
-    val indexDf = spark.createDataFrame(input.map(_.toRow).asJava, indexSchema)
+    assertEquals(expectedOutput, rows)
+  }
 
-    val rows = indexDf.where(new Column(lookupFilter))
-      .select("fileName")
-      .collect()
-      .map(_.getString(0))
-      .toSeq
+  @ParameterizedTest
+  @MethodSource(Array(
+    "testMiscLookupFilterExpressionsSource"
+  ))
+  def testMiscLookupFilterExpressions(filterExpr: Expression, input: Seq[IndexRow], expectedOutput: Seq[String]): Unit = {
+    // We have to fix the timezone to make sure all date-bound utilities output
+    // is consistent with the fixtures
+    spark.sqlContext.setConf(SESSION_LOCAL_TIMEZONE.key, "UTC")
 
-    assertEquals(output, rows)
+    val resolvedFilterExpr: Expression = resolveExpr(spark, filterExpr, sourceTableSchema)
+    val rows: Seq[String] = applyFilterExpr(resolvedFilterExpr, input)
+
+    assertEquals(expectedOutput, rows)
   }
 
+
   @ParameterizedTest
   @MethodSource(Array("testStringsLookupFilterExpressionsSource"))
   def testStringsLookupFilterExpressions(sourceExpr: Expression, input: Seq[IndexRow], output: Seq[String]): Unit = {
@@ -124,6 +133,18 @@ class TestDataSkippingUtils extends HoodieClientTestBase with SparkAdapterSuppor
 
     assertEquals(output, rows)
   }
+
+  private def applyFilterExpr(resolvedExpr: Expression, input: Seq[IndexRow]): Seq[String] = {
+    val lookupFilter = DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr(resolvedExpr, indexSchema)
+
+    val indexDf = spark.createDataFrame(input.map(_.toRow).asJava, indexSchema)
+
+    indexDf.where(new Column(lookupFilter))
+      .select("fileName")
+      .collect()
+      .map(_.getString(0))
+      .toSeq
+  }
 }
 
 object TestDataSkippingUtils {
@@ -159,6 +180,23 @@ object TestDataSkippingUtils {
     )
   }
 
+  def testMiscLookupFilterExpressionsSource(): java.util.stream.Stream[Arguments] = {
+    // NOTE: Have to use [[Arrays.stream]], as Scala can't resolve properly 2 overloads for [[Stream.of]]
+    //       (for single element)
+    java.util.Arrays.stream(
+      Array(
+        arguments(
+          InSet(UnresolvedAttribute("A"), HashSet(0, 1)),
+          Seq(
+            IndexRow("file_1", valueCount = 1, 1, 2, 0),
+            IndexRow("file_2", valueCount = 1, -1, 1, 0),
+            IndexRow("file_3", valueCount = 1, -2, -1, 0)
+          ),
+          Seq("file_1", "file_2"))
+      )
+    )
+  }
+
   def testBasicLookupFilterExpressionsSource(): java.util.stream.Stream[Arguments] = {
     java.util.stream.Stream.of(
       // TODO cases