[HUDI-4851] Fixing CSI not handling InSet operator properly #6685

Merged (3 commits, Sep 16, 2022)
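Context for why this matters: Catalyst's OptimizeIn rule rewrites a literal IN-list into the InSet expression once the list outgrows spark.sql.optimizer.inSetConversionThreshold (10 by default), so a translator that only pattern-matches In silently stops pruning files for longer lists. A minimal sketch of a query shape that hits the InSet path (the table path and column name are placeholders, not taken from this PR):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

val spark = SparkSession.builder().master("local[1]").appName("inset-demo").getOrCreate()

// 21 literals exceeds the default conversion threshold of 10, so the optimizer
// turns this In(...) filter into InSet(...) before data skipping ever sees it.
spark.read.format("hudi").load("/tmp/hudi_table")   // placeholder path
  .where(col("A").isin(0 to 20: _*))
  .count()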
DataSkippingUtils.scala
@@ -23,7 +23,7 @@ import org.apache.hudi.common.util.ValidationUtils.checkState
 import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.expressions.Literal.TrueLiteral
-import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, EqualNullSafe, EqualTo, Expression, ExtractValue, GetStructField, GreaterThan, GreaterThanOrEqual, In, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, Or, StartsWith, SubqueryExpression}
+import org.apache.spark.sql.catalyst.expressions.{Alias, And, Attribute, AttributeReference, EqualNullSafe, EqualTo, Expression, ExtractValue, GetStructField, GreaterThan, GreaterThanOrEqual, In, InSet, IsNotNull, IsNull, LessThan, LessThanOrEqual, Literal, Not, Or, StartsWith, SubqueryExpression}
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.hudi.ColumnStatsExpressionUtils._
 import org.apache.spark.sql.types.StructType
@@ -61,7 +61,7 @@ object DataSkippingUtils extends Logging {
     }
   }
 
-  private def tryComposeIndexFilterExpr(sourceExpr: Expression, indexSchema: StructType): Option[Expression] = {
+  private def tryComposeIndexFilterExpr(sourceFilterExpr: Expression, indexSchema: StructType): Option[Expression] = {
     //
     // For translation of the Filter Expression for the Data Table into Filter Expression for Column Stats Index, we're
     // assuming that
@@ -91,7 +91,7 @@
     //    colA_minValue = min(colA) => transform_expr(colA_minValue) = min(transform_expr(colA))
     //    colA_maxValue = max(colA) => transform_expr(colA_maxValue) = max(transform_expr(colA))
     //
-    sourceExpr match {
+    sourceFilterExpr match {
       // If Expression is not resolved, we can't perform the analysis accurately, bailing
       case expr if !expr.resolved => None
 
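The assumption spelled out in the comment above is order preservation: the transform of the stored min/max equals the min/max of the transformed values whenever the transform is order-preserving over the column's domain. A toy check of that commutation rule in plain Scala (illustrative, not Hudi code):

// For an order-preserving transform_expr, transform_expr(colA_minValue) is itself the
// minimum of transform_expr over the column, which is what lets the translation rewrite
// bounds on transform_expr(colA) in terms of colA_minValue / colA_maxValue.
val colA = Seq(-3, 1, 7)
val transformExpr: Int => Int = _ * 2 + 1   // order-preserving
assert(transformExpr(colA.min) == colA.map(transformExpr).min)
assert(transformExpr(colA.max) == colA.map(transformExpr).max)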
@@ -227,6 +227,16 @@
             list.map(lit => genColumnValuesEqualToExpression(colName, lit, targetExprBuilder)).reduce(Or)
           }
 
+      // Filter "expr(colA) in (B1, B2, ...)"
+      // NOTE: [[InSet]] is an optimized version of the [[In]] expression, where every sub-expression w/in the
+      //       set is a static literal
+      case InSet(sourceExpr @ AllowedTransformationExpression(attrRef), hset: Set[Any]) =>
+        getTargetIndexedColumnName(attrRef, indexSchema)
+          .map { colName =>
+            val targetExprBuilder: Expression => Expression = swapAttributeRefInExpr(sourceExpr, attrRef, _)
+            hset.map(value => genColumnValuesEqualToExpression(colName, Literal(value), targetExprBuilder)).reduce(Or)
+          }
+
       // Filter "expr(colA) not in (B1, B2, ...)"
       // Translates to "NOT((colA_minValue = B1 AND colA_maxValue = B1) OR (colA_minValue = B2 AND colA_maxValue = B2))" for index lookup
       // NOTE: This is NOT an inversion of `in (B1, B2, ...)` expr, this is equivalent to "colA != B1 AND colA != B2 AND ..."
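For intuition, the new InSet case ORs together one containment check per set element, so a file is kept iff its [min, max] range could hold at least one of the values. A minimal boolean model of that predicate in plain Scala (illustrative, not the actual Catalyst expressions):

// Mirrors hset.map(value => genColumnValuesEqualToExpression(...)).reduce(Or) above:
// a file may contain some value from the set iff its [min, max] range covers that value.
def mayContainAny(min: Int, max: Int, hset: Set[Int]): Boolean =
  hset.exists(v => min <= v && v <= max)

assert(mayContainAny(1, 2, Set(0, 1)))     // [1, 2] covers 1        => file kept
assert(mayContainAny(-1, 1, Set(0, 1)))    // [-1, 1] covers 0 and 1 => file kept
assert(!mayContainAny(-2, -1, Set(0, 1)))  // [-2, -1] covers neither => file pruned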
@@ -331,8 +341,8 @@ private object ColumnStatsExpressionUtils {
   @inline def genColValueCountExpr: Expression = col(getValueCountColumnNameFor).expr
 
   @inline def genColumnValuesEqualToExpression(colName: String,
-                                              value: Expression,
-                                              targetExprBuilder: Function[Expression, Expression] = Predef.identity): Expression = {
+                                               value: Expression,
+                                               targetExprBuilder: Function[Expression, Expression] = Predef.identity): Expression = {
     val minValueExpr = targetExprBuilder.apply(genColMinValueExpr(colName))
     val maxValueExpr = targetExprBuilder.apply(genColMaxValueExpr(colName))
     // Only case when column C contains value V is when min(C) <= V <= max(C)
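Per that last comment, the equality check compiles down to a range containment test over the index's min/max columns. A hedged sketch (assuming Catalyst's expression nodes; the A_minValue/A_maxValue names follow the index schema convention used in the tests below) of what that predicate looks like for V = 1:

import org.apache.spark.sql.catalyst.expressions.{And, GreaterThanOrEqual, LessThanOrEqual, Literal}
import org.apache.spark.sql.functions.col

// min(A) <= 1 AND max(A) >= 1, i.e. the only case in which column A can contain 1.
val containsOne = And(
  LessThanOrEqual(col("A_minValue").expr, Literal(1)),
  GreaterThanOrEqual(col("A_maxValue").expr, Literal(1))
)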
TestDataSkippingUtils.scala
@@ -20,7 +20,8 @@ package org.apache.hudi
 import org.apache.hudi.ColumnStatsIndexSupport.composeIndexSchema
 import org.apache.hudi.testutils.HoodieClientTestBase
 import org.apache.spark.sql.HoodieCatalystExpressionUtils.resolveExpr
-import org.apache.spark.sql.catalyst.expressions.{Expression, Not}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.{Expression, InSet, Not}
 import org.apache.spark.sql.functions.{col, lower}
 import org.apache.spark.sql.hudi.DataSkippingUtils
 import org.apache.spark.sql.internal.SQLConf.SESSION_LOCAL_TIMEZONE
@@ -34,6 +35,7 @@ import org.junit.jupiter.params.provider.{Arguments, MethodSource}
 
 import java.sql.Timestamp
 import scala.collection.JavaConverters._
+import scala.collection.immutable.HashSet
 
 // NOTE: Only A, B columns are indexed
 case class IndexRow(fileName: String,
@@ -80,31 +82,38 @@ class TestDataSkippingUtils extends HoodieClientTestBase with SparkAdapterSupport
   val indexSchema: StructType = composeIndexSchema(indexedCols, sourceTableSchema)
 
   @ParameterizedTest
-  @MethodSource(
-    Array(
-      "testBasicLookupFilterExpressionsSource",
-      "testAdvancedLookupFilterExpressionsSource",
-      "testCompositeFilterExpressionsSource"
-    ))
-  def testLookupFilterExpressions(sourceExpr: String, input: Seq[IndexRow], output: Seq[String]): Unit = {
+  @MethodSource(Array(
+    "testBasicLookupFilterExpressionsSource",
+    "testAdvancedLookupFilterExpressionsSource",
+    "testCompositeFilterExpressionsSource"
+  ))
+  def testLookupFilterExpressions(sourceFilterExprStr: String, input: Seq[IndexRow], expectedOutput: Seq[String]): Unit = {
     // We have to fix the timezone to make sure all date-bound utilities output
     // is consistent with the fixtures
     spark.sqlContext.setConf(SESSION_LOCAL_TIMEZONE.key, "UTC")
 
-    val resolvedExpr: Expression = resolveExpr(spark, sourceExpr, sourceTableSchema)
-    val lookupFilter = DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr(resolvedExpr, indexSchema)
-
-    val indexDf = spark.createDataFrame(input.map(_.toRow).asJava, indexSchema)
-
-    val rows = indexDf.where(new Column(lookupFilter))
-      .select("fileName")
-      .collect()
-      .map(_.getString(0))
-      .toSeq
+    val resolvedFilterExpr: Expression = resolveExpr(spark, sourceFilterExprStr, sourceTableSchema)
+    val rows: Seq[String] = applyFilterExpr(resolvedFilterExpr, input)
 
-    assertEquals(output, rows)
+    assertEquals(expectedOutput, rows)
+  }
+
+  @ParameterizedTest
+  @MethodSource(Array(
+    "testMiscLookupFilterExpressionsSource"
+  ))
+  def testMiscLookupFilterExpressions(filterExpr: Expression, input: Seq[IndexRow], expectedOutput: Seq[String]): Unit = {
+    // We have to fix the timezone to make sure all date-bound utilities output
+    // is consistent with the fixtures
+    spark.sqlContext.setConf(SESSION_LOCAL_TIMEZONE.key, "UTC")
+
+    val resolvedFilterExpr: Expression = resolveExpr(spark, filterExpr, sourceTableSchema)
+    val rows: Seq[String] = applyFilterExpr(resolvedFilterExpr, input)
+
+    assertEquals(expectedOutput, rows)
   }
 
+
   @ParameterizedTest
   @MethodSource(Array("testStringsLookupFilterExpressionsSource"))
   def testStringsLookupFilterExpressions(sourceExpr: Expression, input: Seq[IndexRow], output: Seq[String]): Unit = {
@@ -124,6 +133,18 @@
 
     assertEquals(output, rows)
   }
+
+  private def applyFilterExpr(resolvedExpr: Expression, input: Seq[IndexRow]): Seq[String] = {
+    val lookupFilter = DataSkippingUtils.translateIntoColumnStatsIndexFilterExpr(resolvedExpr, indexSchema)
+
+    val indexDf = spark.createDataFrame(input.map(_.toRow).asJava, indexSchema)
+
+    indexDf.where(new Column(lookupFilter))
+      .select("fileName")
+      .collect()
+      .map(_.getString(0))
+      .toSeq
+  }
 }
 
 object TestDataSkippingUtils {
@@ -159,6 +180,23 @@
     )
   }
 
+  def testMiscLookupFilterExpressionsSource(): java.util.stream.Stream[Arguments] = {
+    // NOTE: We have to use [[Arrays.stream]], since Scala can't properly resolve between the 2 overloads
+    //       of [[Stream.of]] (one taking a single element)
+    java.util.Arrays.stream(
+      Array(
+        arguments(
+          InSet(UnresolvedAttribute("A"), HashSet(0, 1)),
+          Seq(
+            IndexRow("file_1", valueCount = 1, 1, 2, 0),
+            IndexRow("file_2", valueCount = 1, -1, 1, 0),
+            IndexRow("file_3", valueCount = 1, -2, -1, 0)
+          ),
+          Seq("file_1", "file_2"))
+      )
+    )
+  }
+
   def testBasicLookupFilterExpressionsSource(): java.util.stream.Stream[Arguments] = {
     java.util.stream.Stream.of(
       // TODO cases
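Working through the new fixture: for InSet(A, {0, 1}) the translated index filter keeps a file exactly when its [A_minValue, A_maxValue] range covers 0 or 1. file_1's range [1, 2] covers 1 and file_2's range [-1, 1] covers both values, while file_3's range [-2, -1] covers neither, so the expected output Seq("file_1", "file_2") prunes only file_3.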