Commit 26dcc05
[HUDI-4851] Fixing CSI not handling InSet operator properly (apache#6685)

(cherry picked from commit 6e31b7c)
Alexey Kudinkin authored and neverdizzy committed Dec 1, 2022
1 parent 630b015 commit 26dcc05
Showing 2 changed files with 70 additions and 22 deletions.
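
For context on what this fixes: when every value in an IN list is a static literal, Spark's OptimizeIn rule rewrites the In predicate into the more compact InSet form once the list grows past spark.sql.optimizer.inSetConversionThreshold (default 10). Before this change tryComposeIndexFilterExpr had no case for InSet, so such filters could not be translated into a column-stats lookup and the files could not be pruned. A minimal usage sketch follows; it assumes a SparkSession named spark, an existing Hudi table with the column-stats index built, and illustrative path and column names (none of these come from the commit itself):

// Illustrative only: a filter with 11 literals reaches data skipping as InSet, not In.
val df = spark.read.format("hudi")
  .option("hoodie.enable.data.skipping", "true")   // assumes the column-stats index is available
  .load("/tmp/hudi/example_table")                 // hypothetical path
  .where("driver_id in (1,2,3,4,5,6,7,8,9,10,11)") // more literals than inSetConversionThreshold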
org/apache/spark/sql/hudi/DataSkippingUtils.scala
@@ -23,7 +23,7 @@ import org.apache.hudi.common.util.ValidationUtils.checkState
import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, EqualNullSafe, EqualTo, Expression, ExtractValue, GetStructField, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, Or, StartsWith, SubqueryExpression}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, EqualNullSafe, EqualTo, Expression, ExtractValue, GetStructField, GreaterThan, GreaterThanOrEqual, In, InSet, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, Or, StartsWith, SubqueryExpression}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.hudi.ColumnStatsExpressionUtils._
import org.apache.spark.sql.types.StructType
@@ -61,7 +61,7 @@ object DataSkippingUtils extends Logging {
}
}

private def tryComposeIndexFilterExpr(sourceExpr: Expression, indexSchema: StructType): Option[Expression] = {
private def tryComposeIndexFilterExpr(sourceFilterExpr: Expression, indexSchema: StructType): Option[Expression] = {
//
// For translation of the Filter Expression for the Data Table into Filter Expression for Column Stats Index, we're
// assuming that
@@ -91,7 +91,7 @@ object DataSkippingUtils extends Logging {
// colA_minValue = min(colA) => transform_expr(colA_minValue) = min(transform_expr(colA))
// colA_maxValue = max(colA) => transform_expr(colA_maxValue) = max(transform_expr(colA))
//
sourceExpr match {
sourceFilterExpr match {
// If Expression is not resolved, we can't perform the analysis accurately, bailing
case expr if !expr.resolved => None

@@ -227,6 +227,16 @@ object DataSkippingUtils extends Logging {
list.map(lit => genColumnValuesEqualToExpression(colName, lit, targetExprBuilder)).reduce(Or)
}

// Filter "expr(colA) in (B1, B2, ...)"
// NOTE: [[InSet]] is an optimized version of the [[In]] expression, where every sub-expression w/in the
// set is a static literal
case InSet(sourceExpr @ AllowedTransformationExpression(attrRef), hset: Set[Any]) =>
getTargetIndexedColumnName(attrRef, indexSchema)
.map { colName =>
val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
hset.map(value => genColumnValuesEqualToExpression(colName, Literal(value), targetExprBuilder)).reduce(Or)
}

// Filter "expr(colA) not in (B1, B2, ...)"
// Translates to "NOT((colA_minValue = B1 AND colA_maxValue = B1) OR (colA_minValue = B2 AND colA_maxValue = B2))" for index lookup
// NOTE: This is NOT an inversion of `in (B1, B2, ...)` expr, this is equivalent to "colA != B1 AND colA != B2 AND ..."
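
A small self-contained sketch of the pruning rule that NOTE describes (values and names are illustrative, not Hudi code): a file whose value range merely overlaps one of the literals may still hold other values, so for "not in" a file can only be dropped when its min and max both equal one of the listed literals.

def definitelyOnlyValue(min: Int, max: Int, v: Int): Boolean = min == v && max == v

// Keep the file unless it provably contains nothing but one of the listed values.
def keepFileForNotIn(min: Int, max: Int, values: Set[Int]): Boolean =
  !values.exists(v => definitelyOnlyValue(min, max, v))

keepFileForNotIn(0, 5, Set(1)) // true  -> kept: range [0, 5] may also hold rows != 1
keepFileForNotIn(1, 1, Set(1)) // false -> pruned: the file can only hold the value 1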
@@ -331,8 +341,8 @@ private object ColumnStatsExpressionUtils {
@inline def genColValueCountExpr: Expression = col(getValueCountColumnNameFor).expr

@inline def genColumnValuesEqualToExpression(colName: String,
value: Expression,
targetExprBuilder: Function[Expression, Expression] = Predef.identity): Expression = {
value: Expression,
targetExprBuilder: Function[Expression, Expression] = Predef.identity): Expression = {
val minValueExpr = targetExprBuilder.apply(genColMinValueExpr(colName))
val maxValueExpr = targetExprBuilder.apply(genColMaxValueExpr(colName))
// Only case when column C contains value V is when min(C) <= V <= max(c)
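To make the new InSet branch concrete, here is a sketch of the shape of the index filter it composes, built with Spark's Catalyst expression classes outside of Hudi. containsValue below is a hypothetical stand-in for genColumnValuesEqualToExpression with an identity transform, and the A_minValue / A_maxValue column names follow the naming used in the comments above; it is not the real implementation.

import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThanOrEqual, Literal, Or}

// A file may contain value v only when A_minValue <= v <= A_maxValue.
def containsValue(colName: String, v: Int): Expression =
  And(
    LessThanOrEqual(UnresolvedAttribute(s"${colName}_minValue"), Literal(v)),
    LessThanOrEqual(Literal(v), UnresolvedAttribute(s"${colName}_maxValue")))

// "A in (0, 1)" arriving as InSet becomes an OR over the set's literals:
val indexFilter: Expression = Set(0, 1).map(v => containsValue("A", v)).reduce(Or)
// ((A_minValue <= 0) AND (0 <= A_maxValue)) OR ((A_minValue <= 1) AND (1 <= A_maxValue))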
org/apache/hudi/TestDataSkippingUtils.scala
@@ -20,7 +20,8 @@ package org.apache.hudi
import org.apache.hudi.ColumnStatsIndexSupport.composeIndexSchema
import org.apache.hudi.testutils.HoodieClientTestBase
import org.apache.spark.sql.HoodieCatalystExpressionUtils.resolveExpr
import org.apache.spark.sql.catalyst.expressions.{Expression, Not}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Expression, InSet, Not}
import org.apache.spark.sql.functions.{col, lower}
import org.apache.spark.sql.hudi.DataSkippingUtils
import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
@@ -34,6 +35,7 @@ import org.junit.jupiter.params.provider.{Arguments, MethodSource}

import java.sql.Timestamp
import scala.collection.JavaConverters._
import scala.collection.immutable.HashSet

// NOTE: Only A, B columns are indexed
case class IndexRow(fileName: String,
@@ -80,31 +82,38 @@ class TestDataSkippingUtils extends HoodieClientTestBase with SparkAdapterSupport
val indexSchema: StructType = composeIndexSchema(indexedCols, sourceTableSchema)

@ParameterizedTest
@MethodSource(
Array(
"testBasicLookupFilterExpressionsSource",
"testAdvancedLookupFilterExpressionsSource",
"testCompositeFilterExpressionsSource"
))
def testLookupFilterExpressions(sourceExpr: String, input: Seq[IndexRow], output: Seq[String]): Unit = {
@MethodSource(Array(
"testBasicLookupFilterExpressionsSource",
"testAdvancedLookupFilterExpressionsSource",
"testCompositeFilterExpressionsSource"
))
def testLookupFilterExpressions(sourceFilterExprStr: String, input: Seq[IndexRow], expectedOutput: Seq[String]): Unit = {
// We have to fix the timezone to make sure all date-bound utilities output
// is consistent with the fixtures
spark.sqlContext.setConf(SESSION_LOCAL_TIMEZONE.key, "UTC")

val resolvedExpr: Expression = resolveExpr(spark, sourceExpr, sourceTableSchema)
val lookupFilter = DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr(resolvedExpr, indexSchema)
val resolvedFilterExpr: Expression = resolveExpr(spark, sourceFilterExprStr, sourceTableSchema)
val rows: Seq[String] = applyFilterExpr(resolvedFilterExpr, input)

val indexDf = spark.createDataFrame(input.map(_.toRow).asJava, indexSchema)
assertEquals(expectedOutput, rows)
}

val rows = indexDf.where(new Column(lookupFilter))
.select("fileName")
.collect()
.map(_.getString(0))
.toSeq
@ParameterizedTest
@MethodSource(Array(
"testMiscLookupFilterExpressionsSource"
))
def testMiscLookupFilterExpressions(filterExpr: Expression, input: Seq[IndexRow], expectedOutput: Seq[String]): Unit = {
// We have to fix the timezone to make sure all date-bound utilities output
// is consistent with the fixtures
spark.sqlContext.setConf(SESSION_LOCAL_TIMEZONE.key, "UTC")

assertEquals(output, rows)
val resolvedFilterExpr: Expression = resolveExpr(spark, filterExpr, sourceTableSchema)
val rows: Seq[String] = applyFilterExpr(resolvedFilterExpr, input)

assertEquals(expectedOutput, rows)
}


@ParameterizedTest
@MethodSource(Array("testStringsLookupFilterExpressionsSource"))
def testStringsLookupFilterExpressions(sourceExpr: Expression, input: Seq[IndexRow], output: Seq[String]): Unit = {
@@ -124,6 +133,18 @@

assertEquals(output, rows)
}

private def applyFilterExpr(resolvedExpr: Expression, input: Seq[IndexRow]): Seq[String] = {
val lookupFilter = DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr(resolvedExpr, indexSchema)

val indexDf = spark.createDataFrame(input.map(_.toRow).asJava, indexSchema)

indexDf.where(new Column(lookupFilter))
.select("fileName")
.collect()
.map(_.getString(0))
.toSeq
}
}

object TestDataSkippingUtils {
@@ -159,6 +180,23 @@ object TestDataSkippingUtils {
)
}

def testMiscLookupFilterExpressionsSource(): java.util.stream.Stream[Arguments] = {
// NOTE: Have to use [[Arrays.stream]], as Scala can't resolve properly 2 overloads for [[Stream.of]]
// (for single element)
java.util.Arrays.stream(
Array(
arguments(
InSet(UnresolvedAttribute("A"), HashSet(0, 1)),
Seq(
IndexRow("file_1", valueCount = 1, 1, 2, 0),
IndexRow("file_2", valueCount = 1, -1, 1, 0),
IndexRow("file_3", valueCount = 1, -2, -1, 0)
),
Seq("file_1", "file_2"))
)
)
}

def testBasicLookupFilterExpressionsSource(): java.util.stream.Stream[Arguments] = {
java.util.stream.Stream.of(
// TODO cases
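Reading the new testMiscLookupFilterExpressionsSource fixture: assuming the three numbers after valueCount are column A's min value, max value, and null count, a file survives the InSet(A, {0, 1}) lookup when its [min, max] range contains 0 or 1, which is why file_1 and file_2 are expected while file_3 is pruned. A tiny standalone check of that rule (illustrative only, not part of the patch):

def mayContainAny(min: Int, max: Int, values: Set[Int]): Boolean =
  values.exists(v => min <= v && v <= max)

mayContainAny(1, 2, Set(0, 1))   // file_1 -> true  (1 falls within [1, 2])
mayContainAny(-1, 1, Set(0, 1))  // file_2 -> true  (0 and 1 fall within [-1, 1])
mayContainAny(-2, -1, Set(0, 1)) // file_3 -> false (pruned)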
