Skip to content

Commit

Permalink
Add create and refresh index
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Dai <daichen@amazon.com>
  • Loading branch information
dai-chen committed Jun 28, 2023
1 parent becdf5f commit 0069384
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,33 @@ statement
;

// Entry rule for all skipping-index SQL statements; dispatches to one
// alternative per supported verb (CREATE / REFRESH / DESC(RIBE) / DROP).
// NOTE(review): this is a rendered diff — the first ": describeSkippingIndexStatement"
// line below is the REMOVED pre-change alternative shown alongside the four
// added alternatives; the post-change rule has exactly one ":" followed by "|" lines.
skippingIndexStatement
: describeSkippingIndexStatement
: createSkippingIndexStatement
| refreshSkippingIndexStatement
| describeSkippingIndexStatement
| dropSkippingIndexStatement
;

// CREATE SKIPPING INDEX ON <table> (col1 TYPE, col2 TYPE, ...)
// The parenthesized list declares which columns to index and with which
// skip-type (PARTITION | VALUE_SET | MIN_MAX); consumed by
// FlintSparkSqlAstBuilder.visitCreateSkippingIndexStatement.
createSkippingIndexStatement
: CREATE SKIPPING INDEX ON tableName=multipartIdentifier
LEFT_PAREN indexColTypeList RIGHT_PAREN
;

// REFRESH SKIPPING INDEX ON <table> — triggers a (full) rebuild of the
// index data for the named table.
refreshSkippingIndexStatement
: REFRESH SKIPPING INDEX ON tableName=multipartIdentifier
;

// DESC | DESCRIBE SKIPPING INDEX ON <table> — shows the index metadata.
describeSkippingIndexStatement
: (DESC | DESCRIBE) SKIPPING INDEX ON tableName=multipartIdentifier
;

// DROP SKIPPING INDEX ON <table> — deletes the skipping index.
dropSkippingIndexStatement
: DROP SKIPPING INDEX ON tableName=multipartIdentifier
;

// Comma-separated, non-empty list of indexed columns.
indexColTypeList
: indexColType (COMMA indexColType)*
;

// A single indexed column: column name followed by its skip-type keyword.
indexColType
: identifier skipType=(PARTITION | VALUE_SET | MIN_MAX)
;
16 changes: 13 additions & 3 deletions flint/flint-spark-integration/src/main/antlr4/SparkSqlBase.g4
Original file line number Diff line number Diff line change
Expand Up @@ -106,20 +106,30 @@ nonReserved

// Flint lexical tokens

// Flint-specific keywords for skipping-index DDL.
// NOTE(review): this section is a rendered diff — SKIPPING (here and below),
// DOT, and MINUS each appear twice because the removed pre-change line is
// shown next to the re-sorted added line; the post-change file defines each
// token exactly once.
SKIPPING : 'SKIPPING';
MIN_MAX: 'MIN_MAX';
SKIPPING: 'SKIPPING';
VALUE_SET: 'VALUE_SET';


// Spark lexical tokens

SEMICOLON: ';';

// Punctuation used by the skipping-index column list syntax.
LEFT_PAREN: '(';
RIGHT_PAREN: ')';
COMMA: ',';
DOT: '.';

// Keywords, kept in alphabetical order.
CREATE: 'CREATE';
DESC: 'DESC';
DESCRIBE: 'DESCRIBE';
DOT: '.';
DROP: 'DROP';
INDEX: 'INDEX';
MINUS: '-';
ON: 'ON';
PARTITION: 'PARTITION';
REFRESH: 'REFRESH';

MINUS: '-';

IDENTIFIER
: (LETTER | DIGIT | '_')+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@

package org.opensearch.flint.spark.sql

import java.util.Locale

import org.opensearch.flint.spark.FlintSpark.RefreshMode
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.{DescribeSkippingIndexStatementContext, DropSkippingIndexStatementContext}
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.{CreateSkippingIndexStatementContext, DescribeSkippingIndexStatementContext, DropSkippingIndexStatementContext, RefreshSkippingIndexStatementContext}

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.AttributeReference
Expand All @@ -19,6 +22,35 @@ import org.apache.spark.sql.types.StringType
*/
class FlintSparkSqlAstBuilder extends FlintSparkSqlExtensionsBaseVisitor[Command] {

/**
 * Builds the command backing `CREATE SKIPPING INDEX ON <table> (...)`.
 *
 * Each declared column is registered on the index builder according to its
 * skip-type keyword (PARTITION / VALUE_SET / MIN_MAX), then the index is
 * created. Returns no rows.
 */
override def visitCreateSkippingIndexStatement(
    ctx: CreateSkippingIndexStatementContext): Command =
  FlintSparkSqlCommand() { flint =>
    val builder = flint
      .skippingIndex()
      .onTable(ctx.tableName.getText)

    // Register every (column, skip-type) pair from the parsed column list.
    ctx.indexColTypeList().indexColType().forEach { col =>
      val column = col.identifier().getText
      // Grammar restricts skipType to these three tokens, so no default case.
      col.skipType.getText.toLowerCase(Locale.ROOT) match {
        case "partition" => builder.addPartitions(column)
        case "value_set" => builder.addValueSet(column)
        case "min_max" => builder.addMinMax(column)
      }
    }
    builder.create()
    Seq.empty
  }

/**
 * Builds the command backing `REFRESH SKIPPING INDEX ON <table>`.
 *
 * Resolves the Flint index name from the table identifier and triggers a
 * full refresh of that index. Returns no rows.
 */
override def visitRefreshSkippingIndexStatement(
    ctx: RefreshSkippingIndexStatementContext): Command =
  FlintSparkSqlCommand() { flint =>
    flint.refreshIndex(
      getSkippingIndexName(ctx.tableName.getText),
      RefreshMode.FULL)
    Seq.empty
  }

override def visitDescribeSkippingIndexStatement(
ctx: DescribeSkippingIndexStatementContext): Command = {
val outputSchema = Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,27 +9,32 @@ import scala.Option.empty

import org.opensearch.flint.OpenSearchSuite
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.scalatest.matchers.must.Matchers.defined
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.spark.FlintSuite
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.flint.config.FlintSparkConf.{HOST_ENDPOINT, HOST_PORT}
import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
import org.apache.spark.sql.flint.config.FlintSparkConf.{HOST_ENDPOINT, HOST_PORT, REFRESH_POLICY}

class FlintSparkSqlSuite extends QueryTest with FlintSuite with OpenSearchSuite {

/** Flint Spark high level API for assertion */
private lazy val flint: FlintSpark = {
setFlintSparkConf(HOST_ENDPOINT, openSearchHost)
setFlintSparkConf(HOST_PORT, openSearchPort)
new FlintSpark(spark)
}
private lazy val flint: FlintSpark = new FlintSpark(spark)

/** Test table and index name */
private val testTable = "test"
private val testIndex = getSkippingIndexName(testTable)

override def beforeAll(): Unit = {
super.beforeAll()

// Configure for FlintSpark explicit created above and the one behind Flint SQL
setFlintSparkConf(HOST_ENDPOINT, openSearchHost)
setFlintSparkConf(HOST_PORT, openSearchPort)
setFlintSparkConf(REFRESH_POLICY, true)

// Create test table
sql(s"""
| CREATE TABLE $testTable
| (
Expand All @@ -46,6 +51,12 @@ class FlintSparkSqlSuite extends QueryTest with FlintSuite with OpenSearchSuite
| month INT
| )
|""".stripMargin)

sql(s"""
| INSERT INTO $testTable
| PARTITION (year=2023, month=4)
| VALUES ('Hello', 30)
| """.stripMargin)
}

protected override def beforeEach(): Unit = {
Expand All @@ -64,6 +75,26 @@ class FlintSparkSqlSuite extends QueryTest with FlintSuite with OpenSearchSuite
flint.deleteIndex(testIndex)
}

// End-to-end check: CREATE builds index metadata only (no data), and a
// subsequent explicit REFRESH populates it. Statement order matters here:
// the count is asserted both before and after the refresh.
test("create skipping index with manual refresh") {
// NOTE(review): presumably an index already exists from test setup;
// delete it so CREATE below starts clean — confirm against beforeEach.
flint.deleteIndex(testIndex)
sql(s"""
| CREATE SKIPPING INDEX ON $testTable
| (
| year PARTITION,
| name VALUE_SET,
| age MIN_MAX
| )
| """.stripMargin)

// Lazily-evaluated view over the index's backing OpenSearch storage;
// each count() below re-reads the current contents.
val indexData = spark.read.format(FLINT_DATASOURCE).options(openSearchOptions).load(testIndex)

// CREATE registers metadata but writes no index data yet.
flint.describeIndex(testIndex) shouldBe defined
indexData.count() shouldBe 0

// Manual refresh populates the index; beforeAll inserted exactly one row.
sql(s"REFRESH SKIPPING INDEX ON $testTable")
indexData.count() shouldBe 1
}

test("describe skipping index") {
val result = sql(s"DESC SKIPPING INDEX ON $testTable")

Expand Down

0 comments on commit 0069384

Please sign in to comment.