Query rewrite for partition skipping index (#1690)
* Add skipping data file index

Signed-off-by: Chen Dai <daichen@amazon.com>

* Fix query plan integrity check failure

Signed-off-by: Chen Dai <daichen@amazon.com>

* Add comments for critical code path

Signed-off-by: Chen Dai <daichen@amazon.com>

* Add check for filtering conditions that cannot be rewritten

Signed-off-by: Chen Dai <daichen@amazon.com>

* Add logging

Signed-off-by: Chen Dai <daichen@amazon.com>

* Add more comments and logging

Signed-off-by: Chen Dai <daichen@amazon.com>

* Add Flint enabled config

Signed-off-by: Chen Dai <daichen@amazon.com>

* Add more IT for query rewrite

Signed-off-by: Chen Dai <daichen@amazon.com>

* Refactor to use Flint data type mapping

Signed-off-by: Chen Dai <daichen@amazon.com>

* Polish doc and comments for PR review

Signed-off-by: Chen Dai <daichen@amazon.com>

* Addressed PR comments

Signed-off-by: Chen Dai <daichen@amazon.com>

---------

Signed-off-by: Chen Dai <daichen@amazon.com>
dai-chen committed Jun 5, 2023
1 parent 7dd57ce commit 35d5813
Showing 12 changed files with 324 additions and 29 deletions.
1 change: 1 addition & 0 deletions flint/docs/index.md
@@ -198,6 +198,7 @@ In the index mapping, the `_meta` and `properties` fields store meta and schema info
- `spark.datasource.flint.write.refresh_policy`: default value is false. valid values [NONE(false),
IMMEDIATE(true), WAIT_UNTIL(wait_for)]
- `spark.datasource.flint.read.scroll_size`: default value is 100.
- `spark.flint.optimizer.enabled`: default value is true.
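
As an illustrative aside (not part of this commit's diff), a minimal sketch of wiring the Flint extension and the new flag into a session; the application name is a placeholder:

```scala
import org.apache.spark.sql.SparkSession

// Sketch only: register the Flint extension and set the optimizer flag added in this commit.
val spark = SparkSession.builder()
  .appName("flint-example") // placeholder name
  .config("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions")
  .config("spark.flint.optimizer.enabled", "true")
  .getOrCreate()

// FlintSparkOptimizer reads the flag from the session conf on each invocation,
// so it can also be flipped at runtime to compare plans with and without the rewrite.
spark.conf.set("spark.flint.optimizer.enabled", "false")
```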

#### API

@@ -82,6 +82,10 @@ object FlintSparkConf {
val SCROLL_SIZE = FlintConfig("read.scroll_size")
.doc("scroll read size")
.createWithDefault("100")

val OPTIMIZER_RULE_ENABLED = FlintConfig("spark.flint.optimizer.enabled")
.doc("Enable Flint optimizer rule for query rewrite with Flint index")
.createWithDefault("true")
}

class FlintSparkConf(properties: JMap[String, String]) extends Serializable {
@@ -97,6 +101,8 @@ class FlintSparkConf(properties: JMap[String, String]) extends Serializable {
else throw new NoSuchElementException("index or path not found")
}

def isOptimizerEnabled: Boolean = OPTIMIZER_RULE_ENABLED.readFrom(reader).toBoolean

/**
* Helper class, create {@link FlintOptions}.
*/
@@ -5,9 +5,18 @@

package org.opensearch.flint.spark

import org.opensearch.flint.spark.skipping.ApplyFlintSparkSkippingIndex

import org.apache.spark.sql.SparkSessionExtensions

/**
* Flint Spark extension entrypoint.
*/
class FlintSparkExtensions extends (SparkSessionExtensions => Unit) {

override def apply(v1: SparkSessionExtensions): Unit = {}
override def apply(extensions: SparkSessionExtensions): Unit = {
extensions.injectOptimizerRule { spark =>
new FlintSparkOptimizer(spark)
}
}
}
@@ -0,0 +1,42 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark

import scala.collection.JavaConverters._

import org.opensearch.flint.spark.skipping.ApplyFlintSparkSkippingIndex

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.flint.config.FlintSparkConf

/**
* Flint Spark optimizer that manages all Flint-related optimizer rules.
* @param spark
* Spark session
*/
class FlintSparkOptimizer(spark: SparkSession) extends Rule[LogicalPlan] {

/** Flint Spark API */
private val flint: FlintSpark = new FlintSpark(spark)

/** Only one Flint optimizer rule for now. Cost estimation will be needed once more rules are added in the future. */
private val rule = new ApplyFlintSparkSkippingIndex(flint)

override def apply(plan: LogicalPlan): LogicalPlan = {
if (isOptimizerEnabled) {
rule.apply(plan)
} else {
plan
}
}

private def isOptimizerEnabled: Boolean = {
val flintConf = new FlintSparkConf(spark.conf.getAll.asJava)
flintConf.isOptimizerEnabled
}
}
@@ -0,0 +1,85 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping

import org.opensearch.flint.spark.FlintSpark
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE}

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.catalyst.expressions.{And, Predicate}
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation}
import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE

/**
* Flint Spark skipping index apply rule that rewrites an applicable query's filtering condition and
* table scan operator to leverage the additional skipping data structure, accelerating the query by
* significantly reducing the data scanned.
*
* @param flint
* Flint Spark API
*/
class ApplyFlintSparkSkippingIndex(flint: FlintSpark) extends Rule[LogicalPlan] {

override def apply(plan: LogicalPlan): LogicalPlan = plan transform {
case filter @ Filter( // TODO: abstract pattern match logic for different table support
condition: Predicate,
relation @ LogicalRelation(
baseRelation @ HadoopFsRelation(location, _, _, _, _, _),
_,
Some(table),
false)) if !location.isInstanceOf[FlintSparkSkippingFileIndex] =>

val indexName = getSkippingIndexName(table.identifier.table) // TODO: database name
val index = flint.describeIndex(indexName)
if (index.exists(_.kind == SKIPPING_INDEX_TYPE)) {
val skippingIndex = index.get.asInstanceOf[FlintSparkSkippingIndex]
val indexPred = rewriteToIndexPredicate(skippingIndex, condition)

/*
* Replace original file index with Flint skipping file index:
* Filter(a=b)
* |- LogicalRelation(A)
* |- HadoopFsRelation
* |- FileIndex <== replaced with FlintSkippingFileIndex
*/
if (indexPred.isDefined) {
val filterByIndex = buildFilterIndexQuery(skippingIndex, indexPred.get)
val fileIndex = new FlintSparkSkippingFileIndex(location, filterByIndex)
val indexRelation = baseRelation.copy(location = fileIndex)(baseRelation.sparkSession)
filter.copy(child = relation.copy(relation = indexRelation))
} else {
filter
}
} else {
filter
}
}

private def rewriteToIndexPredicate(
index: FlintSparkSkippingIndex,
condition: Predicate): Option[Predicate] = {

// TODO: currently only handles conjunctions, i.e. the given condition consists of
// one or more expressions concatenated by AND only.
index.indexedColumns
.flatMap(index => index.rewritePredicate(condition))
.reduceOption(And(_, _))
}

private def buildFilterIndexQuery(
index: FlintSparkSkippingIndex,
rewrittenPredicate: Predicate): DataFrame = {

// Get file list based on the rewritten predicates on index data
flint.spark.read
.format(FLINT_DATASOURCE)
.load(index.name())
.filter(new Column(rewrittenPredicate))
.select(FILE_PATH_COLUMN)
}
}
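
To make the rewrite concrete, a hypothetical walk-through (the table and partition columns are made up, and a skipping index on those columns is assumed to already exist and be refreshed):

```scala
// Hypothetical setup: `alb_logs` is partitioned by `year` and `month`, and a Flint
// skipping index covering both columns has already been created and refreshed.
val df = spark.sql("SELECT * FROM alb_logs WHERE year = 2023 AND month = 6")

// With spark.flint.optimizer.enabled=true, ApplyFlintSparkSkippingIndex:
//   1. matches the Filter over the table's HadoopFsRelation,
//   2. rewrites the conjunctive condition into a predicate on the index data,
//   3. queries the Flint index for the matching file_path values, and
//   4. swaps the relation's FileIndex for a FlintSparkSkippingFileIndex,
// so the scan reads only the files returned by the index query.
df.explain(true)
```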
@@ -0,0 +1,55 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping

import org.apache.hadoop.fs.{FileStatus, Path}

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution.datasources.{FileIndex, PartitionDirectory}
import org.apache.spark.sql.types.StructType

/**
* File index that skips source files based on the files selected by the Flint skipping index.
*
* @param baseFileIndex
* original file index
* @param filterByIndex
* pushed down filtering on index data
*/
class FlintSparkSkippingFileIndex(baseFileIndex: FileIndex, filterByIndex: DataFrame)
extends FileIndex {

override def listFiles(
partitionFilters: Seq[Expression],
dataFilters: Seq[Expression]): Seq[PartitionDirectory] = {

val selectedFiles =
filterByIndex.collect
.map(_.getString(0))
.toSet

// TODO: figure out if list file call can be avoided
val partitions = baseFileIndex.listFiles(partitionFilters, dataFilters)
partitions
.map(p => p.copy(files = p.files.filter(f => isFileNotSkipped(selectedFiles, f))))
.filter(p => p.files.nonEmpty)
}

override def rootPaths: Seq[Path] = baseFileIndex.rootPaths

override def inputFiles: Array[String] = baseFileIndex.inputFiles

override def refresh(): Unit = baseFileIndex.refresh()

override def sizeInBytes: Long = baseFileIndex.sizeInBytes

override def partitionSchema: StructType = baseFileIndex.partitionSchema

private def isFileNotSkipped(selectedFiles: Set[String], f: FileStatus) = {
selectedFiles.contains(f.getPath.toUri.toString)
}
}
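
A small sketch of the string comparison that `isFileNotSkipped` relies on (bucket and paths are invented): the selected set holds the `file_path` strings stored in the index data, presumably the URI form produced by `input_file_name()` during refresh, so a listed file is kept only when its path, in URI form, matches exactly.

```scala
import org.apache.hadoop.fs.{FileStatus, Path}

// Invented paths for illustration only.
val selectedFiles = Set("s3a://my-bucket/alb_logs/year=2023/month=6/part-00000.parquet")

val file = new FileStatus(
  1024L,              // length
  false,              // isDirectory
  1,                  // block replication
  128L * 1024 * 1024, // block size
  0L,                 // modification time
  new Path("s3a://my-bucket/alb_logs/year=2023/month=6/part-00000.parquet"))

// Mirrors isFileNotSkipped: keep the file only if its URI string is in the selected set.
val kept = selectedFiles.contains(file.getPath.toUri.toString) // true
```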
@@ -6,22 +6,27 @@
package org.opensearch.flint.spark.skipping

import org.json4s._
import org.json4s.native.JsonMethods._
import org.json4s.native.Serialization
import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.spark.FlintSparkIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE}

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.catalyst.dsl.expressions.DslExpression
import org.apache.spark.sql.flint.datatype.FlintDataType
import org.apache.spark.sql.functions.input_file_name
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}

/**
* Flint skipping index in Spark.
*
* @param tableName
* source table name
*/
class FlintSparkSkippingIndex(tableName: String, indexedColumns: Seq[FlintSparkSkippingStrategy])
class FlintSparkSkippingIndex(
tableName: String,
val indexedColumns: Seq[FlintSparkSkippingStrategy])
extends FlintSparkIndex {

/** Required by json4s write function */
@@ -30,15 +35,6 @@ class FlintSparkSkippingIndex(tableName: String, indexedColumns: Seq[FlintSparkSkippingStrategy])
/** Skipping index type */
override val kind: String = SKIPPING_INDEX_TYPE

/** Output schema of the skipping index */
private val outputSchema: Map[String, String] = {
val schema = indexedColumns
.flatMap(_.outputSchema().toList)
.toMap

schema + (FILE_PATH_COLUMN -> "keyword")
}

override def name(): String = {
getSkippingIndexName(tableName)
}
@@ -74,9 +70,20 @@ class FlintSparkSkippingIndex(tableName: String, indexedColumns: Seq[FlintSparkSkippingStrategy])
}

private def getSchema: String = {
Serialization.write(outputSchema.map { case (colName, colType) =>
colName -> ("type" -> colType)
})
val indexFieldTypes = indexedColumns.map { indexCol =>
val columnName = indexCol.columnName
// Data type INT from catalog is not recognized by Spark DataType.fromJson()
val columnType = if (indexCol.columnType == "int") "integer" else indexCol.columnType
val sparkType = DataType.fromJson("\"" + columnType + "\"")
StructField(columnName, sparkType, nullable = false)
}

val allFieldTypes =
indexFieldTypes :+ StructField(FILE_PATH_COLUMN, StringType, nullable = false)

// Convert StructType to {"properties": ...} JSON and keep only the properties value
val properties = FlintDataType.serialize(StructType(allFieldTypes))
compact(render(parse(properties) \ "properties"))
}
}
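
A quick illustration (not from the diff) of why `getSchema` remaps `int` to `integer` before calling `DataType.fromJson`: Spark's JSON type names use `integer`, whereas the catalog reports the column type as `int`.

```scala
import org.apache.spark.sql.types.DataType

// "integer" is the JSON type name Spark expects for IntegerType.
val intType = DataType.fromJson("\"integer\"")

// The catalog-reported name "int" is not accepted here (per the comment above),
// hence the remap in getSchema.
// DataType.fromJson("\"int\"")  // fails
```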

@@ -5,6 +5,7 @@

package org.opensearch.flint.spark.skipping

import org.apache.spark.sql.catalyst.expressions.Predicate
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateFunction

/**
@@ -18,9 +19,14 @@ trait FlintSparkSkippingStrategy {
val kind: String

/**
* Indexed column name and its Spark SQL type.
* Indexed column name.
*/
val columnName: String

/**
* Indexed column Spark SQL type.
*/
@transient
val columnType: String

/**
@@ -34,4 +40,15 @@
* aggregators that generate skipping data structure
*/
def getAggregators: Seq[AggregateFunction]

/**
* Rewrite a filtering condition on source table into a new predicate on index data based on
* current skipping strategy.
*
* @param predicate
* filtering condition on source table
* @return
* new filtering condition on index data, or None if the index is not applicable
*/
def rewritePredicate(predicate: Predicate): Option[Predicate]
}
@@ -8,6 +8,8 @@ package org.opensearch.flint.spark.skipping.partition
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal, Predicate}
import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateFunction, First}

/**
@@ -20,18 +22,18 @@ class PartitionSkippingStrategy(
extends FlintSparkSkippingStrategy {

override def outputSchema(): Map[String, String] = {
Map(columnName -> convertToFlintType(columnType))
Map(columnName -> columnType)
}

override def getAggregators: Seq[AggregateFunction] = {
Seq(First(new Column(columnName).expr, ignoreNulls = true))
}

// TODO: move this mapping info to single place
private def convertToFlintType(colType: String): String = {
colType match {
case "string" => "keyword"
case "int" => "integer"
}
override def rewritePredicate(predicate: Predicate): Option[Predicate] = {
// The column has the same name in the index data, so rewrite to the same equality predicate
predicate.collect {
case EqualTo(AttributeReference(`columnName`, _, _, _), value: Literal) =>
EqualTo(UnresolvedAttribute(columnName), value)
}.headOption
}
}
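
A sketch of what `rewritePredicate` produces for a partition column; the constructor parameter names are an assumption, since the class declaration is collapsed in the hunk above.

```scala
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal}
import org.apache.spark.sql.types.IntegerType

// Assumed constructor shape; the real parameter list is hidden by the collapsed diff.
val strategy = new PartitionSkippingStrategy(columnName = "year", columnType = "int")

// A resolved equality from the source table's filter, e.g. `year = 2023`.
val sourcePredicate = EqualTo(AttributeReference("year", IntegerType)(), Literal(2023))

// Rewritten into the same equality against the (still unresolved) `year` column of the index data.
val indexPredicate = strategy.rewritePredicate(sourcePredicate)
// expected: Some(EqualTo(UnresolvedAttribute("year"), Literal(2023)))
```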
@@ -5,6 +5,8 @@

package org.apache.spark

import org.opensearch.flint.spark.FlintSparkExtensions

import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.flint.config.{FlintConfigEntry, FlintSparkConf}
@@ -22,6 +24,7 @@ trait FlintSuite extends SharedSparkSession {
// this rule may potentially block testing of other optimization rules such as
// ConstantPropagation etc.
.set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
.set("spark.sql.extensions", classOf[FlintSparkExtensions].getName)
conf
}

