Skip to content

Commit

Permalink
Implement describe index and IT
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Dai <daichen@amazon.com>
  • Loading branch information
dai-chen committed Jun 27, 2023
1 parent 8d5c964 commit 719a953
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,30 +5,45 @@

package org.opensearch.flint.spark.sql

import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIndexName
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.{DescribeSkippingIndexStatementContext, DropSkippingIndexStatementContext}

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.AttributeReference
import org.apache.spark.sql.catalyst.plans.logical.Command
import org.apache.spark.sql.types.StringType

/**
* Flint Spark AST builder that builds a Spark command for each Flint index statement.
*
* Every visitor wraps its Flint API call in a FlintSparkSqlCommand code block instead of
* calling the API directly while visiting the parse tree.
*/
class FlintSparkSqlAstBuilder extends FlintSparkSqlExtensionsBaseVisitor[Command] {

/**
* Builds the command for a describe-skipping-index statement.
*
* The command declares two non-nullable string output columns, `indexed_col_name` and
* `data_type`, and produces one row per indexed column of the skipping index. It produces no
* rows when `describeIndex` finds nothing or finds an index that is not a skipping index.
*
* @param ctx
*   parse tree context of the statement; its table name identifies the skipping index
* @return
*   command that describes the skipping index when it is run
*/
override def visitDescribeSkippingIndexStatement(
ctx: DescribeSkippingIndexStatementContext): Command =
FlintSparkSqlCommand { flint =>
Seq.empty
ctx: DescribeSkippingIndexStatementContext): Command = {
val outputSchema = Seq(
AttributeReference("indexed_col_name", StringType, nullable = false)(),
AttributeReference("data_type", StringType, nullable = false)())

FlintSparkSqlCommand(outputSchema) { flint =>
val indexName = getSkippingIndexName(ctx.tableName.getText)
// collect rather than map: the type pattern below is partial, so map would throw a
// MatchError for any index that is not a FlintSparkSkippingIndex. With collect such an
// index falls through to the empty result instead.
flint
.describeIndex(indexName)
.collect { case index: FlintSparkSkippingIndex =>
index.indexedColumns.map(strategy => Row(strategy.columnName, strategy.columnType))
}
.getOrElse(Seq.empty)
}
}

/**
* Builds the command for a drop-skipping-index statement.
*
* @param ctx
*   parse tree context of the statement; its table name identifies the skipping index
* @return
*   command that deletes the skipping index when it is run and produces no rows
*/
override def visitDropSkippingIndexStatement(ctx: DropSkippingIndexStatementContext): Command =
FlintSparkSqlCommand { flint =>
FlintSparkSqlCommand() { flint =>
val tableName = ctx.tableName.getText // TODO: handle schema name
val indexName = getSkippingIndexName(tableName)
flint.deleteIndex(indexName)
Seq.empty
}

/**
* Returns `nextResult` unless it is null, in which case the command aggregated so far is kept.
*/
override def aggregateResult(aggregate: Command, nextResult: Command): Command =
if (nextResult != null) nextResult else aggregate;
if (nextResult != null) nextResult else aggregate
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package org.opensearch.flint.spark.sql
import org.opensearch.flint.spark.FlintSpark

import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.execution.command.LeafRunnableCommand

/**
Expand All @@ -19,7 +20,12 @@ import org.apache.spark.sql.execution.command.LeafRunnableCommand
* @param block
* code block that triggers Flint core API
*/
case class FlintSparkSqlCommand(block: FlintSpark => Seq[Row]) extends LeafRunnableCommand {
case class FlintSparkSqlCommand(override val output: Seq[Attribute] = Seq.empty)(
block: FlintSpark => Seq[Row])
extends LeafRunnableCommand {

// Runs the code block against a FlintSpark instance created from the given session and
// returns the rows the block produces.
override def run(sparkSession: SparkSession): Seq[Row] = block(new FlintSpark(sparkSession))

// `block` is declared in the second (curried) parameter list, so it is not one of the case
// class's own constructor fields. Returning it here lets it be passed along with the
// first-list arguments when Spark makes a copy of this node.
override protected def otherCopyArgs: Seq[AnyRef] = block :: Nil
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.getSkippingIn
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper

import org.apache.spark.FlintSuite
import org.apache.spark.sql.QueryTest
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.flint.config.FlintSparkConf.{HOST_ENDPOINT, HOST_PORT}

class FlintSparkSqlSuite extends QueryTest with FlintSuite with OpenSearchSuite {
Expand Down Expand Up @@ -55,6 +55,7 @@ class FlintSparkSqlSuite extends QueryTest with FlintSuite with OpenSearchSuite
.skippingIndex()
.onTable(testTable)
.addPartitions("year")
.addValueSet("name")
.create()
}

Expand All @@ -67,7 +68,14 @@ class FlintSparkSqlSuite extends QueryTest with FlintSuite with OpenSearchSuite
// Describing the skipping index is expected to return one (column name, data type) row
// for each indexed column.
test("describe skipping index") {
val result = sql(s"DESC SKIPPING INDEX ON $testTable")

checkAnswer(result, Seq())
checkAnswer(result, Seq(Row("year", "int"), Row("name", "string")))
}

test("should return empty if no skipping index to describe") {
  // Delete the index first so that there is nothing left to describe.
  flint.deleteIndex(testIndex)

  // The statement is expected to yield no rows once the index is gone.
  val described = sql(s"DESC SKIPPING INDEX ON $testTable")
  checkAnswer(described, Seq.empty)
}

test("drop skipping index") {
Expand Down

0 comments on commit 719a953

Please sign in to comment.