Provide hybrid scan setting for consistency requirement (#1819)
* Add UT for file index

Signed-off-by: Chen Dai <daichen@amazon.com>

* Add hybrid scan config and IT

Signed-off-by: Chen Dai <daichen@amazon.com>

* Implement select file logic for hybrid scan mode

Signed-off-by: Chen Dai <daichen@amazon.com>

* Add IT

Signed-off-by: Chen Dai <daichen@amazon.com>

---------

Signed-off-by: Chen Dai <daichen@amazon.com>
dai-chen committed Jul 11, 2023
1 parent 91b2a06 commit 5da4f0a
Showing 7 changed files with 235 additions and 24 deletions.
1 change: 1 addition & 0 deletions flint/docs/index.md
@@ -219,6 +219,7 @@ In the index mapping, the `_meta` and `properties` field stores meta and schema i
IMMEDIATE(true), WAIT_UNTIL(wait_for)]
- `spark.datasource.flint.read.scroll_size`: default value is 100.
- `spark.flint.optimizer.enabled`: default is true.
- `spark.flint.index.hybridscan.enabled`: default is false.
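
For illustration only (not part of this commit), the new flag is toggled like any other Spark configuration. A minimal sketch, assuming an existing `SparkSession` named `spark` and a made-up table name:

```scala
// Hypothetical usage sketch: enable hybrid scan for one query, then restore the
// default. Only the config key comes from this change; the table and query are invented.
spark.conf.set("spark.flint.index.hybridscan.enabled", "true")
try {
  spark.sql("SELECT address FROM sample_table WHERE month = 4").show()
} finally {
  spark.conf.set("spark.flint.index.hybridscan.enabled", "false")
}
```

The test changes below wrap this same enable/disable pattern in a `withHybridScanEnabled` helper.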

#### Data Type Mapping

@@ -92,6 +92,10 @@ object FlintSparkConf {
val OPTIMIZER_RULE_ENABLED = FlintConfig("spark.flint.optimizer.enabled")
.doc("Enable Flint optimizer rule for query rewrite with Flint index")
.createWithDefault("true")

val HYBRID_SCAN_ENABLED = FlintConfig("spark.flint.index.hybridscan.enabled")
.doc("Enable hybrid scan to include latest source data not refreshed to index yet")
.createWithDefault("false")
}

/**
@@ -114,6 +118,8 @@ class FlintSparkConf(properties: JMap[String, String]) extends Serializable {

def isOptimizerEnabled: Boolean = OPTIMIZER_RULE_ENABLED.readFrom(reader).toBoolean

def isHybridScanEnabled: Boolean = HYBRID_SCAN_ENABLED.readFrom(reader).toBoolean

/**
* spark.sql.session.timeZone
*/
@@ -6,9 +6,9 @@
package org.opensearch.flint.spark.skipping

import org.opensearch.flint.spark.FlintSpark
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE}

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.expressions.{And, Predicate}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -38,7 +38,7 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
val index = flint.describeIndex(indexName)
if (index.exists(_.kind == SKIPPING_INDEX_TYPE)) {
val skippingIndex = index.get.asInstanceOf[FlintSparkSkippingIndex]
val indexPred = rewriteToIndexPredicate(skippingIndex, condition)
val indexFilter = rewriteToIndexFilter(skippingIndex, condition)

/*
* Replace original file index with Flint skipping file index:
@@ -47,9 +47,9 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
* |- HadoopFsRelation
* |- FileIndex <== replaced with FlintSkippingFileIndex
*/
if (indexPred.isDefined) {
val filterByIndex = buildFilterIndexQuery(skippingIndex, indexPred.get)
val fileIndex = new FlintSparkSkippingFileIndex(location, filterByIndex)
if (indexFilter.isDefined) {
val indexScan = buildIndexScan(skippingIndex)
val fileIndex = FlintSparkSkippingFileIndex(location, indexScan, indexFilter.get)
val indexRelation = baseRelation.copy(location = fileIndex)(baseRelation.sparkSession)
filter.copy(child = relation.copy(relation = indexRelation))
} else {
@@ -60,7 +60,7 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
}
}

private def rewriteToIndexPredicate(
private def rewriteToIndexFilter(
index: FlintSparkSkippingIndex,
condition: Predicate): Option[Predicate] = {

@@ -71,15 +71,9 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
.reduceOption(And(_, _))
}

private def buildFilterIndexQuery(
index: FlintSparkSkippingIndex,
rewrittenPredicate: Predicate): DataFrame = {

// Get file list based on the rewritten predicates on index data
private def buildIndexScan(index: FlintSparkSkippingIndex): DataFrame = {
flint.spark.read
.format(FLINT_DATASOURCE)
.load(index.name())
.filter(new Column(rewrittenPredicate))
.select(FILE_PATH_COLUMN)
}
}
@@ -6,33 +6,43 @@
package org.opensearch.flint.spark.skipping

import org.apache.hadoop.fs.{FileStatus, Path}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.FILE_PATH_COLUMN

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.catalyst.expressions.{Expression, Predicate}
import org.apache.spark.sql.execution.datasources.{FileIndex, PartitionDirectory}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.functions.isnull
import org.apache.spark.sql.types.StructType

/**
* File index that skips source files based on the selected files by Flint skipping index.
*
* @param baseFileIndex
* original file index
* @param filterByIndex
* pushed down filtering on index data
* @param indexScan
* query skipping index DF with pushed down filters
*/
case class FlintSparkSkippingFileIndex(baseFileIndex: FileIndex, filterByIndex: DataFrame)
case class FlintSparkSkippingFileIndex(
baseFileIndex: FileIndex,
indexScan: DataFrame,
indexFilter: Predicate)
extends FileIndex {

override def listFiles(
partitionFilters: Seq[Expression],
dataFilters: Seq[Expression]): Seq[PartitionDirectory] = {

// TODO: make this listFile call only in hybrid scan mode
val partitions = baseFileIndex.listFiles(partitionFilters, dataFilters)
val selectedFiles =
filterByIndex.collect
.map(_.getString(0))
.toSet
if (FlintSparkConf().isHybridScanEnabled) {
selectFilesFromIndexAndSource(partitions)
} else {
selectFilesFromIndexOnly()
}

val partitions = baseFileIndex.listFiles(partitionFilters, dataFilters)
// Keep partition files present in selected file list above
partitions
.map(p => p.copy(files = p.files.filter(f => isFileNotSkipped(selectedFiles, f))))
.filter(p => p.files.nonEmpty)
@@ -48,6 +58,44 @@ case class FlintSparkSkippingFileIndex(baseFileIndex: FileIndex, filterByIndex:

override def partitionSchema: StructType = baseFileIndex.partitionSchema

/*
* Left join source partitions and index data to keep unknown source files:
* Express the logic in SQL:
* SELECT left.file_path
* FROM partitions AS left
* LEFT JOIN indexScan AS right
* ON left.file_path = right.file_path
* WHERE right.file_path IS NULL
* OR [indexFilter]
*/
private def selectFilesFromIndexAndSource(partitions: Seq[PartitionDirectory]): Set[String] = {
val sparkSession = indexScan.sparkSession
import sparkSession.implicits._

partitions
.flatMap(_.files.map(f => f.getPath.toUri.toString))
.toDF(FILE_PATH_COLUMN)
.join(indexScan, Seq(FILE_PATH_COLUMN), "left")
.filter(isnull(indexScan(FILE_PATH_COLUMN)) || new Column(indexFilter))
.select(FILE_PATH_COLUMN)
.collect()
.map(_.getString(0))
.toSet
}

/*
* Consider file paths in index data alone. In this case, index filter can be pushed down
* to index store.
*/
private def selectFilesFromIndexOnly(): Set[String] = {
indexScan
.filter(new Column(indexFilter))
.select(FILE_PATH_COLUMN)
.collect
.map(_.getString(0))
.toSet
}

private def isFileNotSkipped(selectedFiles: Set[String], f: FileStatus) = {
selectedFiles.contains(f.getPath.toUri.toString)
}
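To make the selection rule above more concrete, here is a self-contained sketch of the same left-join idea on toy data. It is not part of the commit: the object name, file paths, and `year` filter are invented, and it uses explicit aliases rather than the USING-column join and `isnull` call in `selectFilesFromIndexAndSource`:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.col

// Hypothetical demo of the hybrid-scan rule: keep a source file if it is missing
// from the index (not refreshed yet) OR its index entry matches the index filter.
object HybridScanSelectionDemo extends App {
  val spark = SparkSession.builder().master("local[*]").appName("hybrid-scan-demo").getOrCreate()
  import spark.implicits._

  // Files currently present under the source table location.
  val sourceFiles = Seq("file-1", "file-2").toDF("file_path")
  // Skipping index data; file-2 has not been refreshed into the index yet.
  val indexScan = Seq(("file-1", 2023)).toDF("file_path", "year")

  val selected = sourceFiles.as("src")
    .join(indexScan.as("idx"), col("src.file_path") === col("idx.file_path"), "left")
    .filter(col("idx.file_path").isNull || col("idx.year") === 2023)
    .select(col("src.file_path"))

  // Expect both rows back: file-1 matches the filter, file-2 is unknown to the index.
  selected.show()
  spark.stop()
}
```

When hybrid scan is disabled, the class instead takes the `selectFilesFromIndexOnly` path, where the index filter can be pushed down to the index store.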
@@ -10,6 +10,7 @@ import org.opensearch.flint.spark.FlintSparkExtensions
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.flint.config.FlintConfigEntry
import org.apache.spark.sql.flint.config.FlintSparkConf.HYBRID_SCAN_ENABLED
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession

@@ -34,4 +35,13 @@ trait FlintSuite extends SharedSparkSession {
protected def setFlintSparkConf[T](config: FlintConfigEntry[T], value: Any): Unit = {
spark.conf.set(config.key, value.toString)
}

protected def withHybridScanEnabled(block: => Unit): Unit = {
setFlintSparkConf(HYBRID_SCAN_ENABLED, "true")
try {
block
} finally {
setFlintSparkConf(HYBRID_SCAN_ENABLED, "false")
}
}
}
@@ -0,0 +1,126 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping

import org.apache.hadoop.fs.{FileStatus, Path}
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito.when
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.FILE_PATH_COLUMN
import org.scalatest.matchers.should.Matchers
import org.scalatestplus.mockito.MockitoSugar.mock

import org.apache.spark.FlintSuite
import org.apache.spark.sql.{Column, DataFrame, Row}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Literal, Predicate}
import org.apache.spark.sql.execution.datasources.{FileIndex, PartitionDirectory}
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.types._

class FlintSparkSkippingFileIndexSuite extends FlintSuite with Matchers {

/** Test source partition data. */
private val partition1 = "partition-1" -> Seq("file-1", "file-2")
private val partition2 = "partition-2" -> Seq("file-3")

/** Test index data schema. */
private val schema = Map((FILE_PATH_COLUMN, StringType), ("year", IntegerType))

test("should keep files returned from index") {
assertFlintFileIndex()
.withSourceFiles(Map(partition1))
.withIndexData(schema, Seq(Row("file-1", 2023), Row("file-2", 2022)))
.withIndexFilter(col("year") === 2023)
.shouldScanSourceFiles(Map("partition-1" -> Seq("file-1")))
}

test("should keep files of multiple partitions returned from index") {
assertFlintFileIndex()
.withSourceFiles(Map(partition1, partition2))
.withIndexData(schema, Seq(Row("file-1", 2023), Row("file-2", 2022), Row("file-3", 2023)))
.withIndexFilter(col("year") === 2023)
.shouldScanSourceFiles(Map("partition-1" -> Seq("file-1"), "partition-2" -> Seq("file-3")))
}

test("should skip unknown source files by default") {
assertFlintFileIndex()
.withSourceFiles(Map(partition1))
.withIndexData(
schema,
Seq(Row("file-1", 2023)) // file-2 is not refreshed to index yet
)
.withIndexFilter(col("year") === 2023)
.shouldScanSourceFiles(Map("partition-1" -> Seq("file-1")))
}

test("should not skip unknown source files in hybrid-scan mode") {
withHybridScanEnabled {
assertFlintFileIndex()
.withSourceFiles(Map(partition1))
.withIndexData(
schema,
Seq(Row("file-1", 2023)) // file-2 is not refreshed to index yet
)
.withIndexFilter(col("year") === 2023)
.shouldScanSourceFiles(Map("partition-1" -> Seq("file-1", "file-2")))
}
}

test("should not skip unknown source files of multiple partitions in hybrid-scan mode") {
withHybridScanEnabled {
assertFlintFileIndex()
.withSourceFiles(Map(partition1, partition2))
.withIndexData(
schema,
Seq(Row("file-1", 2023)) // file-2 is not refreshed to index yet
)
.withIndexFilter(col("year") === 2023)
.shouldScanSourceFiles(
Map("partition-1" -> Seq("file-1", "file-2"), "partition-2" -> Seq("file-3")))
}
}

private def assertFlintFileIndex(): AssertionHelper = {
new AssertionHelper
}

private class AssertionHelper {
private val baseFileIndex = mock[FileIndex]
private var indexScan: DataFrame = _
private var indexFilter: Predicate = _

def withSourceFiles(partitions: Map[String, Seq[String]]): AssertionHelper = {
when(baseFileIndex.listFiles(any(), any()))
.thenReturn(mockPartitions(partitions))
this
}

def withIndexData(columns: Map[String, DataType], data: Seq[Row]): AssertionHelper = {
val schema = StructType(columns.map { case (colName, colType) =>
StructField(colName, colType, nullable = false)
}.toSeq)
indexScan = spark.createDataFrame(spark.sparkContext.parallelize(data), schema)
this
}

def withIndexFilter(pred: Column): AssertionHelper = {
indexFilter = pred.expr.asInstanceOf[Predicate]
this
}

def shouldScanSourceFiles(partitions: Map[String, Seq[String]]): Unit = {
val fileIndex = FlintSparkSkippingFileIndex(baseFileIndex, indexScan, indexFilter)
fileIndex.listFiles(Seq.empty, Seq.empty) shouldBe mockPartitions(partitions)
}

private def mockPartitions(partitions: Map[String, Seq[String]]): Seq[PartitionDirectory] = {
partitions.map { case (partitionName, filePaths) =>
val files = filePaths.map(path => new FileStatus(0, false, 0, 0, 0, new Path(path)))
PartitionDirectory(InternalRow(Literal(partitionName)), files)
}.toSeq
}
}
}
@@ -301,6 +301,32 @@ class FlintSparkSkippingIndexSuite
hasIndexFilter(col("MinMax_age_0") <= 25 && col("MinMax_age_1") >= 25))
}

test("should rewrite applicable query to scan latest source files in hybrid scan mode") {
flint
.skippingIndex()
.onTable(testTable)
.addPartitions("month")
.create()
flint.refreshIndex(testIndex, FULL)

// Generate a new source file which is not in index data
sql(s"""
| INSERT INTO $testTable
| PARTITION (year=2023, month=4)
| VALUES ('Hello', 35, 'Vancouver')
| """.stripMargin)

withHybridScanEnabled {
val query = sql(s"""
| SELECT address
| FROM $testTable
| WHERE month = 4
|""".stripMargin)

checkAnswer(query, Seq(Row("Seattle"), Row("Vancouver")))
}
}

test("should return empty if describe index not exist") {
flint.describeIndex("non-exist") shouldBe empty
}
@@ -333,7 +359,7 @@ class FlintSparkSkippingIndexSuite
// Custom matcher to check if FlintSparkSkippingFileIndex has expected filter condition
def hasIndexFilter(expect: Column): Matcher[FlintSparkSkippingFileIndex] = {
Matcher { (fileIndex: FlintSparkSkippingFileIndex) =>
val plan = fileIndex.filterByIndex.queryExecution.logical
val plan = fileIndex.indexScan.queryExecution.logical
val hasExpectedFilter = plan.find {
case Filter(actual, _) =>
actual.semanticEquals(expect.expr)