Merge branch 'feature/flint' into add-doc-id
dai-chen committed Jul 11, 2023
2 parents d7bff3e + 5da4f0a commit adeed60
Showing 12 changed files with 310 additions and 37 deletions.
1 change: 1 addition & 0 deletions flint/docs/index.md
@@ -219,6 +219,7 @@ In the index mapping, the `_meta` and `properties` field stores meta and schema info
IMMEDIATE(true), WAIT_UNTIL(wait_for)]
- `spark.datasource.flint.read.scroll_size`: default value is 100.
- `spark.flint.optimizer.enabled`: default is true.
- `spark.flint.index.hybridscan.enabled`: default is false.
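
To illustrate where these options plug in, here is a minimal sketch of setting them on a Spark session (the session builder and application name are assumptions; only the option keys come from the list above):

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical session setup; only the option keys are taken from the docs above.
val spark = SparkSession
  .builder()
  .appName("flint-config-example")
  .config("spark.flint.optimizer.enabled", "true") // default: true
  .config("spark.flint.index.hybridscan.enabled", "true") // default: false
  .getOrCreate()

// The hybrid scan flag can also be toggled on an existing session.
spark.conf.set("spark.flint.index.hybridscan.enabled", "false")
```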

#### Data Type Mapping

flint/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4
@@ -19,7 +19,12 @@ statement
;

skippingIndexStatement
: dropSkippingIndexStatement
: describeSkippingIndexStatement
| dropSkippingIndexStatement
;

describeSkippingIndexStatement
: (DESC | DESCRIBE) SKIPPING INDEX ON tableName=multipartIdentifier
;

dropSkippingIndexStatement
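As a quick illustration of the new describe rule, a hedged sketch of exercising both keyword forms through spark.sql (assumes a session with the Flint SQL extensions registered; `alb_logs` is a hypothetical table name):

```scala
import org.apache.spark.sql.SparkSession

// Assumes the Flint extensions are installed on the active session;
// "alb_logs" is only an illustrative table name.
val spark = SparkSession.active

spark.sql("DESCRIBE SKIPPING INDEX ON alb_logs").show()
spark.sql("DESC SKIPPING INDEX ON alb_logs").show()
```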
2 changes: 2 additions & 0 deletions flint/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4
@@ -113,6 +113,8 @@ SKIPPING : 'SKIPPING';

SEMICOLON: ';';

DESC: 'DESC';
DESCRIBE: 'DESCRIBE';
DOT: '.';
DROP: 'DROP';
INDEX: 'INDEX';
flint/flint-spark-integration/src/main/scala/org/apache/spark/sql/flint/config/FlintSparkConf.scala
@@ -92,6 +92,10 @@ object FlintSparkConf
val OPTIMIZER_RULE_ENABLED = FlintConfig("spark.flint.optimizer.enabled")
.doc("Enable Flint optimizer rule for query rewrite with Flint index")
.createWithDefault("true")

val HYBRID_SCAN_ENABLED = FlintConfig("spark.flint.index.hybridscan.enabled")
.doc("Enable hybrid scan to include latest source data not refreshed to index yet")
.createWithDefault("false")
}

/**
@@ -114,6 +118,8 @@ class FlintSparkConf(properties: JMap[String, String]) extends Serializable {

def isOptimizerEnabled: Boolean = OPTIMIZER_RULE_ENABLED.readFrom(reader).toBoolean

def isHybridScanEnabled: Boolean = HYBRID_SCAN_ENABLED.readFrom(reader).toBoolean

/**
* spark.sql.session.timeZone
*/
flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/ApplyFlintSparkSkippingIndex.scala
@@ -6,9 +6,9 @@
package org.opensearch.flint.spark.skipping

import org.opensearch.flint.spark.FlintSpark
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, SKIPPING_INDEX_TYPE}

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.expressions.{And, Predicate}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
@@ -38,7 +38,7 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
val index = flint.describeIndex(indexName)
if (index.exists(_.kind == SKIPPING_INDEX_TYPE)) {
val skippingIndex = index.get.asInstanceOf[FlintSparkSkippingIndex]
val indexPred = rewriteToIndexPredicate(skippingIndex, condition)
val indexFilter = rewriteToIndexFilter(skippingIndex, condition)

/*
* Replace original file index with Flint skipping file index:
@@ -47,9 +47,9 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
* |- HadoopFsRelation
* |- FileIndex <== replaced with FlintSkippingFileIndex
*/
if (indexPred.isDefined) {
val filterByIndex = buildFilterIndexQuery(skippingIndex, indexPred.get)
val fileIndex = new FlintSparkSkippingFileIndex(location, filterByIndex)
if (indexFilter.isDefined) {
val indexScan = buildIndexScan(skippingIndex)
val fileIndex = FlintSparkSkippingFileIndex(location, indexScan, indexFilter.get)
val indexRelation = baseRelation.copy(location = fileIndex)(baseRelation.sparkSession)
filter.copy(child = relation.copy(relation = indexRelation))
} else {
@@ -60,7 +60,7 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
}
}

private def rewriteToIndexPredicate(
private def rewriteToIndexFilter(
index: FlintSparkSkippingIndex,
condition: Predicate): Option[Predicate] = {

@@ -71,15 +71,9 @@ class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan]
.reduceOption(And(_, _))
}

private def buildFilterIndexQuery(
index: FlintSparkSkippingIndex,
rewrittenPredicate: Predicate): DataFrame = {

// Get file list based on the rewritten predicates on index data
private def buildIndexScan(index: FlintSparkSkippingIndex): DataFrame = {
flint.spark.read
.format(FLINT_DATASOURCE)
.load(index.name())
.filter(new Column(rewrittenPredicate))
.select(FILE_PATH_COLUMN)
}
}
flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/skipping/FlintSparkSkippingFileIndex.scala
@@ -6,33 +6,43 @@
package org.opensearch.flint.spark.skipping

import org.apache.hadoop.fs.{FileStatus, Path}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.FILE_PATH_COLUMN

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.catalyst.expressions.{Expression, Predicate}
import org.apache.spark.sql.execution.datasources.{FileIndex, PartitionDirectory}
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.functions.isnull
import org.apache.spark.sql.types.StructType

/**
* File index that skips source files based on the selected files by Flint skipping index.
*
* @param baseFileIndex
* original file index
* @param filterByIndex
* pushed down filtering on index data
* @param indexScan
* query skipping index DF with pushed down filters
*/
case class FlintSparkSkippingFileIndex(baseFileIndex: FileIndex, filterByIndex: DataFrame)
case class FlintSparkSkippingFileIndex(
baseFileIndex: FileIndex,
indexScan: DataFrame,
indexFilter: Predicate)
extends FileIndex {

override def listFiles(
partitionFilters: Seq[Expression],
dataFilters: Seq[Expression]): Seq[PartitionDirectory] = {

// TODO: make this listFile call only in hybrid scan mode
val partitions = baseFileIndex.listFiles(partitionFilters, dataFilters)
val selectedFiles =
filterByIndex.collect
.map(_.getString(0))
.toSet
if (FlintSparkConf().isHybridScanEnabled) {
selectFilesFromIndexAndSource(partitions)
} else {
selectFilesFromIndexOnly()
}

val partitions = baseFileIndex.listFiles(partitionFilters, dataFilters)
// Keep partition files present in selected file list above
partitions
.map(p => p.copy(files = p.files.filter(f => isFileNotSkipped(selectedFiles, f))))
.filter(p => p.files.nonEmpty)
@@ -48,6 +58,44 @@ case class FlintSparkSkippingFileIndex(baseFileIndex: FileIndex, filterByIndex:

override def partitionSchema: StructType = baseFileIndex.partitionSchema

/*
* Left join source partitions and index data to keep unknown source files:
* Express the logic in SQL:
* SELECT left.file_path
* FROM partitions AS left
* LEFT JOIN indexScan AS right
* ON left.file_path = right.file_path
* WHERE right.file_path IS NULL
* OR [indexFilter]
*/
private def selectFilesFromIndexAndSource(partitions: Seq[PartitionDirectory]): Set[String] = {
val sparkSession = indexScan.sparkSession
import sparkSession.implicits._

partitions
.flatMap(_.files.map(f => f.getPath.toUri.toString))
.toDF(FILE_PATH_COLUMN)
.join(indexScan, Seq(FILE_PATH_COLUMN), "left")
.filter(isnull(indexScan(FILE_PATH_COLUMN)) || new Column(indexFilter))
.select(FILE_PATH_COLUMN)
.collect()
.map(_.getString(0))
.toSet
}

/*
* Consider file paths in index data alone. In this case, index filter can be pushed down
* to index store.
*/
private def selectFilesFromIndexOnly(): Set[String] = {
indexScan
.filter(new Column(indexFilter))
.select(FILE_PATH_COLUMN)
.collect
.map(_.getString(0))
.toSet
}

private def isFileNotSkipped(selectedFiles: Set[String], f: FileStatus) = {
selectedFiles.contains(f.getPath.toUri.toString)
}
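The SQL comment in selectFilesFromIndexAndSource above can be played out on toy DataFrames; a minimal sketch (paths, column values, and the `year` filter are made up for illustration; only `file_path` mirrors FILE_PATH_COLUMN):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, isnull}

val spark = SparkSession.active
import spark.implicits._

// Hypothetical source file listing and skipping index contents.
val sourceFiles = Seq(
  "s3://bucket/a.parquet", // indexed, matches the pushed-down filter
  "s3://bucket/b.parquet", // indexed, does not match
  "s3://bucket/new.parquet" // not yet refreshed into the index
).toDF("file_path")

val indexScan = Seq(
  ("s3://bucket/a.parquet", 2023),
  ("s3://bucket/b.parquet", 2020)
).toDF("file_path", "year")

// Hybrid scan keeps files the index says may match, plus files the index
// has not seen yet (their index columns are null after the left join).
val selected = sourceFiles
  .join(indexScan, Seq("file_path"), "left")
  .filter(isnull(col("year")) || col("year") === 2023)
  .select("file_path")
  .as[String]
  .collect()
  .toSet
// selected == Set("s3://bucket/a.parquet", "s3://bucket/new.parquet")
```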
flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlAstBuilder.scala
@@ -5,26 +5,47 @@

package org.opensearch.flint.spark.sql

import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.DropSkippingIndexStatementContext
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.{DescribeSkippingIndexStatementContext, DropSkippingIndexStatementContext}

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.Command
import org.apache.spark.sql.types.StringType

/**
* Flint Spark AST builder that builds Spark command for Flint index statement.
*/
class FlintSparkSqlAstBuilder extends FlintSparkSqlExtensionsBaseVisitor[Command] {

override def visitDropSkippingIndexStatement(
ctx: DropSkippingIndexStatementContext): Command = {
FlintSparkSqlCommand { flint =>
override def visitDescribeSkippingIndexStatement(
ctx: DescribeSkippingIndexStatementContext): Command = {
val outputSchema = Seq(
AttributeReference("indexed_col_name", StringType, nullable = false)(),
AttributeReference("data_type", StringType, nullable = false)(),
AttributeReference("skip_type", StringType, nullable = false)())

FlintSparkSqlCommand(outputSchema) { flint =>
val indexName = getSkippingIndexName(ctx.tableName.getText)
flint
.describeIndex(indexName)
.map { case index: FlintSparkSkippingIndex =>
index.indexedColumns.map(strategy =>
Row(strategy.columnName, strategy.columnType, strategy.kind.toString))
}
.getOrElse(Seq.empty)
}
}

override def visitDropSkippingIndexStatement(ctx: DropSkippingIndexStatementContext): Command =
FlintSparkSqlCommand() { flint =>
val tableName = ctx.tableName.getText // TODO: handle schema name
val indexName = getSkippingIndexName(tableName)
flint.deleteIndex(indexName)
Seq.empty
}
}

override def aggregateResult(aggregate: Command, nextResult: Command): Command =
if (nextResult != null) nextResult else aggregate;
if (nextResult != null) nextResult else aggregate
}
flint/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlCommand.scala
@@ -8,6 +8,7 @@ package org.opensearch.flint.spark.sql
import org.opensearch.flint.spark.FlintSpark

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.command.LeafRunnableCommand

/**
@@ -19,7 +20,12 @@ import org.apache.spark.sql.execution.command.LeafRunnableCommand
* @param block
* code block that triggers Flint core API
*/
case class FlintSparkSqlCommand(block: FlintSpark => Seq[Row]) extends LeafRunnableCommand {
case class FlintSparkSqlCommand(override val output: Seq[Attribute] = Seq.empty)(
block: FlintSpark => Seq[Row])
extends LeafRunnableCommand {

override def run(sparkSession: SparkSession): Seq[Row] = block(new FlintSpark(sparkSession))

// Arguments in the second (curried) parameter list must be passed explicitly when Catalyst copies this node
override protected def otherCopyArgs: Seq[AnyRef] = block :: Nil
}
@@ -10,6 +10,7 @@ import org.opensearch.flint.spark.FlintSparkExtensions
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.flint.config.FlintConfigEntry
import org.apache.spark.sql.flint.config.FlintSparkConf.HYBRID_SCAN_ENABLED
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.test.SharedSparkSession

@@ -34,4 +35,13 @@ trait FlintSuite extends SharedSparkSession {
protected def setFlintSparkConf[T](config: FlintConfigEntry[T], value: Any): Unit = {
spark.conf.set(config.key, value.toString)
}

protected def withHybridScanEnabled(block: => Unit): Unit = {
setFlintSparkConf(HYBRID_SCAN_ENABLED, "true")
try {
block
} finally {
setFlintSparkConf(HYBRID_SCAN_ENABLED, "false")
}
}
}
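
A hedged sketch of how a suite mixing in FlintSuite might use the new helper (the suite, test name, and assertions are assumptions, not part of this commit):

```scala
// Assumes this class lives alongside FlintSuite so the trait is visible without an import.
class HybridScanConfigSuite extends FlintSuite {

  test("hybrid scan flag is restored after the block") {
    withHybridScanEnabled {
      // Inside the block the flag is set to "true".
      assert(spark.conf.get("spark.flint.index.hybridscan.enabled") == "true")
    }
    // The helper resets it to "false" in its finally clause.
    assert(spark.conf.get("spark.flint.index.hybridscan.enabled") == "false")
  }
}
```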