Support auto refresh property in WITH clause
Signed-off-by: Chen Dai <daichen@amazon.com>
dai-chen committed Jul 5, 2023
1 parent 8657832 commit 23c87a3
Showing 6 changed files with 110 additions and 6 deletions.
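In short, this change lets a CREATE SKIPPING INDEX statement opt into incremental (streaming) refresh directly in the DDL via a WITH clause. A minimal sketch of the new syntax, assuming a Spark session with the Flint extension enabled and a hypothetical table db.tbl with a partitioned year column (mirroring the new integration test below):

    // Sketch only: table and column names are hypothetical.
    spark.sql(
      "CREATE SKIPPING INDEX ON db.tbl (year PARTITION) WITH (auto_refresh = true)")

When the property is true, the statement also starts the incremental refresh right after creating the index; without it, behavior is unchanged.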
@@ -28,6 +28,7 @@ skippingIndexStatement
createSkippingIndexStatement
: CREATE SKIPPING INDEX ON tableName=multipartIdentifier
LEFT_PAREN indexColTypeList RIGHT_PAREN
(WITH LEFT_PAREN propertyList RIGHT_PAREN)?
;

refreshSkippingIndexStatement
46 changes: 46 additions & 0 deletions flint/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4
@@ -85,6 +85,31 @@ grammar SparkSqlBase;
}


propertyList
: property (COMMA property)*
;

property
: key=propertyKey (EQ? value=propertyValue)?
;

propertyKey
: identifier (DOT identifier)*
| STRING
;

propertyValue
: INTEGER_VALUE
| DECIMAL_VALUE
| booleanValue
| STRING
;

booleanValue
: TRUE | FALSE
;


multipartIdentifier
: parts+=identifier (DOT parts+=identifier)*
;
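For reference, the new property rules above accept lists such as the following. These are illustrative fragments with a hypothetical table name; only auto_refresh is interpreted by this commit, other properties are parsed but ignored:

    spark.sql("CREATE SKIPPING INDEX ON db.tbl (year PARTITION) WITH (auto_refresh = true)")  // boolean value
    spark.sql("CREATE SKIPPING INDEX ON db.tbl (year PARTITION) WITH (auto_refresh true)")    // EQ is optional
    spark.sql("CREATE SKIPPING INDEX ON db.tbl (year PARTITION) WITH (some.key = 10)")        // dotted key, integer value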
@@ -120,17 +145,33 @@ RIGHT_PAREN: ')';
COMMA: ',';
DOT: '.';


CREATE: 'CREATE';
DESC: 'DESC';
DESCRIBE: 'DESCRIBE';
DROP: 'DROP';
FALSE: 'FALSE';
INDEX: 'INDEX';
ON: 'ON';
PARTITION: 'PARTITION';
REFRESH: 'REFRESH';
STRING: 'STRING';
TRUE: 'TRUE';
WITH: 'WITH';


EQ : '=' | '==';
MINUS: '-';


INTEGER_VALUE
: DIGIT+
;

DECIMAL_VALUE
: DECIMAL_DIGITS {isValidDecimal()}?
;

IDENTIFIER
: (LETTER | DIGIT | '_')+
;
@@ -139,6 +180,11 @@ BACKQUOTED_IDENTIFIER
: '`' ( ~'`' | '``' )* '`'
;

fragment DECIMAL_DIGITS
: DIGIT+ '.' DIGIT*
| '.' DIGIT+
;

fragment DIGIT
: [0-9]
;
@@ -101,6 +101,7 @@ class FlintSpark(val spark: SparkSession) {
val job = spark.readStream
.table(tableName)
.writeStream
.queryName(indexName)
.outputMode(Append())
.foreachBatch { (batchDF: DataFrame, _: Long) =>
writeFlintIndex(batchDF)
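The added queryName(indexName) call names the streaming query after the Flint index, so the refresh job can later be located by index name; a rough sketch of that lookup (the same pattern the new integration test uses):

    // `indexName` stands for the Flint index name, e.g. the result of getSkippingIndexName(...).
    val job = spark.streams.active.find(_.name == indexName)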
@@ -9,8 +9,8 @@ import org.opensearch.flint.spark.FlintSpark.RefreshMode
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind._
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.{CreateSkippingIndexStatementContext, DescribeSkippingIndexStatementContext, DropSkippingIndexStatementContext, RefreshSkippingIndexStatementContext}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKind.{MIN_MAX, PARTITION, VALUE_SET}
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser._

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.AttributeReference
@@ -25,21 +25,27 @@ class FlintSparkSqlAstBuilder extends FlintSparkSqlExtensionsBaseVisitor[Command
override def visitCreateSkippingIndexStatement(
ctx: CreateSkippingIndexStatementContext): Command =
FlintSparkSqlCommand() { flint =>
// Create skipping index
val indexBuilder = flint
.skippingIndex()
.onTable(ctx.tableName.getText)

ctx.indexColTypeList().indexColType().forEach { colTypeCtx =>
val colName = colTypeCtx.identifier().getText
val skipType = SkippingKind.withName(colTypeCtx.skipType.getText)

skipType match {
case PARTITION => indexBuilder.addPartitions(colName)
case VALUE_SET => indexBuilder.addValueSet(colName)
case MIN_MAX => indexBuilder.addMinMax(colName)
}
}
indexBuilder.create()

// Trigger auto refresh if enabled
if (isAutoRefreshEnabled(ctx.propertyList())) {
val indexName = getSkippingIndexName(ctx.tableName.getText)
flint.refreshIndex(indexName, RefreshMode.INCREMENTAL)
}
Seq.empty
}
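When the auto_refresh property is absent or false, index creation behaves as before and no refresh job is started. The same incremental refresh can still be kicked off programmatically with the calls used above; a sketch with a placeholder table name:

    // Sketch only: "db.tbl" is a placeholder table name.
    val indexName = getSkippingIndexName("db.tbl")
    flint.refreshIndex(indexName, RefreshMode.INCREMENTAL)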

@@ -78,6 +84,21 @@ class FlintSparkSqlAstBuilder extends FlintSparkSqlExtensionsBaseVisitor[Command
Seq.empty
}

private def isAutoRefreshEnabled(ctx: PropertyListContext): Boolean = {
if (ctx == null) {
false
} else {
ctx
.property()
.forEach(p => {
if (p.key.getText == "auto_refresh") {
return p.value.getText.toBoolean
}
})
false
}
}

override def aggregateResult(aggregate: Command, nextResult: Command): Command =
if (nextResult != null) nextResult else aggregate
}
@@ -26,7 +26,7 @@ import org.apache.spark.sql.flint.config.FlintSparkConf._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkSkippingIndexSuite
class FlintSparkSkippingIndexITSuite
extends QueryTest
with FlintSuite
with OpenSearchSuite
@@ -16,8 +16,13 @@ import org.apache.spark.FlintSuite
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
import org.apache.spark.sql.flint.config.FlintSparkConf.{HOST_ENDPOINT, HOST_PORT, REFRESH_POLICY}
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkSqlITSuite extends QueryTest with FlintSuite with OpenSearchSuite {
class FlintSparkSqlITSuite
extends QueryTest
with FlintSuite
with OpenSearchSuite
with StreamTest {

/** Flint Spark high level API for assertion */
private lazy val flint: FlintSpark = new FlintSpark(spark)
@@ -73,6 +78,36 @@ class FlintSparkSqlITSuite extends QueryTest with FlintSuite with OpenSearchSuite
protected override def afterEach(): Unit = {
super.afterEach()
flint.deleteIndex(testIndex)

// Stop all streaming jobs if any
spark.streams.active.foreach { job =>
job.stop()
job.awaitTermination()
}
}

test("create skipping index with auto refresh") {
flint.deleteIndex(testIndex)
sql(s"""
| CREATE SKIPPING INDEX ON $testTable
| (
| year PARTITION,
| name VALUE_SET,
| age MIN_MAX
| )
| WITH (auto_refresh = true)
| """.stripMargin)

// Wait for the streaming job to complete the current micro batch
val job = spark.streams.active.find(_.name == testIndex)
job shouldBe defined
failAfter(streamingTimeout) {
job.get.processAllAvailable()
}

val indexData = spark.read.format(FLINT_DATASOURCE).load(testIndex)
flint.describeIndex(testIndex) shouldBe defined
indexData.count() shouldBe 1
}

test("create skipping index with manual refresh") {
@@ -86,7 +121,7 @@ class FlintSparkSqlITSuite extends QueryTest with FlintSuite with OpenSearchSuite
| )
| """.stripMargin)

val indexData = spark.read.format(FLINT_DATASOURCE).options(openSearchOptions).load(testIndex)
val indexData = spark.read.format(FLINT_DATASOURCE).load(testIndex)

flint.describeIndex(testIndex) shouldBe defined
indexData.count() shouldBe 0