Generate doc ID in build index job for idempotency (#1803)
* Add doc id to DF by Spark built-in sha1 function

Signed-off-by: Chen Dai <daichen@amazon.com>

* Add UT

Signed-off-by: Chen Dai <daichen@amazon.com>

* Fix broken IT

Signed-off-by: Chen Dai <daichen@amazon.com>

* Change id column name to avoid conflict with user data

Signed-off-by: Chen Dai <daichen@amazon.com>

---------

Signed-off-by: Chen Dai <daichen@amazon.com>
dai-chen committed Jul 11, 2023
1 parent 5da4f0a commit b9eb0ea
Showing 6 changed files with 76 additions and 9 deletions.
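As a rough sketch of the idea behind the change (the object name, file paths, and local SparkSession below are made up for illustration and are not part of this commit): Spark's built-in sha1 function is deterministic, so deriving the document ID from a stable key such as the source file path produces the same ID on every run, and re-running the build job overwrites existing documents instead of appending duplicates.

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, sha1}

object DocIdSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("doc-id-sketch").getOrCreate()
    import spark.implicits._

    // Two independent "runs" over the same source keys.
    val run1 = Seq("s3://bucket/part-0000", "s3://bucket/part-0001").toDF("file_path")
    val run2 = Seq("s3://bucket/part-0000", "s3://bucket/part-0001").toDF("file_path")

    // sha1 of the same key is identical across runs, so a write keyed on this
    // column updates the same documents rather than creating new ones.
    val ids1 = run1.withColumn("__id__", sha1(col("file_path")))
    val ids2 = run2.withColumn("__id__", sha1(col("file_path")))
    assert(ids1.collect().toSet == ids2.collect().toSet)

    spark.stop()
  }
}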
3 changes: 2 additions & 1 deletion flint/build.sbt
@@ -50,7 +50,8 @@ lazy val flintCore = (project in file("flint-core"))
scalaVersion := scala212,
libraryDependencies ++= Seq(
"org.opensearch.client" % "opensearch-rest-client" % opensearchVersion,
"org.opensearch.client" % "opensearch-rest-high-level-client" % opensearchVersion,
"org.opensearch.client" % "opensearch-rest-high-level-client" % opensearchVersion
exclude("org.apache.logging.log4j", "log4j-api"),
"com.amazonaws" % "aws-java-sdk" % "1.12.397" % "provided"
exclude("com.fasterxml.jackson.core", "jackson-databind") ))

FlintSpark.scala
@@ -5,13 +5,16 @@

package org.opensearch.flint.spark

import scala.collection.JavaConverters._

import org.json4s.{Formats, JArray, NoTypeHints}
import org.json4s.native.JsonMethods.parse
import org.json4s.native.Serialization
import org.opensearch.flint.core.{FlintClient, FlintClientBuilder}
import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.spark.FlintSpark._
import org.opensearch.flint.spark.FlintSpark.RefreshMode.{FULL, INCREMENTAL, RefreshMode}
import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
import org.opensearch.flint.spark.skipping.{FlintSparkSkippingIndex, FlintSparkSkippingStrategy}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.SKIPPING_INDEX_TYPE
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.{SkippingKind, SkippingKindSerializer}
@@ -25,15 +28,24 @@ import org.apache.spark.sql.SaveMode._
import org.apache.spark.sql.catalog.Column
import org.apache.spark.sql.flint.FlintDataSourceV2.FLINT_DATASOURCE
import org.apache.spark.sql.flint.config.FlintSparkConf
import org.apache.spark.sql.flint.config.FlintSparkConf.{DOC_ID_COLUMN_NAME, IGNORE_DOC_ID_COLUMN}
import org.apache.spark.sql.streaming.OutputMode.Append

/**
* Flint Spark integration API entrypoint.
*/
class FlintSpark(val spark: SparkSession) {

/** Flint spark configuration */
private val flintSparkConf: FlintSparkConf =
FlintSparkConf(
Map(
DOC_ID_COLUMN_NAME.key -> ID_COLUMN,
IGNORE_DOC_ID_COLUMN.key -> "true"
).asJava)

/** Flint client for low-level index operation */
private val flintClient: FlintClient = FlintClientBuilder.build(FlintSparkConf().flintOptions())
private val flintClient: FlintClient = FlintClientBuilder.build(flintSparkConf.flintOptions())

/** Required by json4s parse function */
implicit val formats: Formats = Serialization.formats(NoTypeHints) + SkippingKindSerializer
@@ -254,8 +266,7 @@ object FlintSpark {
require(tableName.nonEmpty, "table name cannot be empty")

val col = findColumn(colName)
addIndexedColumn(
ValueSetSkippingStrategy(columnName = col.name, columnType = col.dataType))
addIndexedColumn(ValueSetSkippingStrategy(columnName = col.name, columnType = col.dataType))
this
}

@@ -269,9 +280,8 @@
*/
def addMinMax(colName: String): IndexBuilder = {
val col = findColumn(colName)
indexedColumns = indexedColumns :+ MinMaxSkippingStrategy(
columnName = col.name,
columnType = col.dataType)
indexedColumns =
indexedColumns :+ MinMaxSkippingStrategy(columnName = col.name, columnType = col.dataType)
this
}

FlintSparkIndex.scala
@@ -42,3 +42,11 @@ trait FlintSparkIndex
*/
def build(df: DataFrame): DataFrame
}

object FlintSparkIndex {

/**
* ID column name.
*/
val ID_COLUMN: String = "__id__"
}
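The double-underscore name lowers the chance of clashing with a column in user data (the last commit message bullet above). A hypothetical guard, not part of this commit, that a caller could use to fail fast if source data already carries the reserved name:

import org.apache.spark.sql.DataFrame
import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN

// Hypothetical helper (illustration only): reject source data that already
// contains the reserved ID column before the build job adds its own.
def requireNoReservedIdColumn(df: DataFrame): Unit =
  require(!df.columns.contains(ID_COLUMN), s"source data must not contain reserved column $ID_COLUMN")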
FlintSparkSkippingIndex.scala
@@ -10,13 +10,14 @@ import org.json4s.native.JsonMethods._
import org.json4s.native.Serialization
import org.opensearch.flint.core.metadata.FlintMetadata
import org.opensearch.flint.spark.FlintSparkIndex
import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.{getSkippingIndexName, FILE_PATH_COLUMN, SKIPPING_INDEX_TYPE}
import org.opensearch.flint.spark.skipping.FlintSparkSkippingStrategy.SkippingKindSerializer

import org.apache.spark.sql.{Column, DataFrame}
import org.apache.spark.sql.catalyst.dsl.expressions.DslExpression
import org.apache.spark.sql.flint.datatype.FlintDataType
import org.apache.spark.sql.functions.input_file_name
import org.apache.spark.sql.functions.{col, input_file_name, sha1}
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}

/**
@@ -30,6 +31,8 @@ class FlintSparkSkippingIndex(
val indexedColumns: Seq[FlintSparkSkippingStrategy])
extends FlintSparkIndex {

require(indexedColumns.nonEmpty, "indexed columns must not be empty")

/** Required by json4s write function */
implicit val formats: Formats = Serialization.formats(NoTypeHints) + SkippingKindSerializer

@@ -64,6 +67,7 @@

df.groupBy(input_file_name().as(FILE_PATH_COLUMN))
.agg(namedAggFuncs.head, namedAggFuncs.tail: _*)
.withColumn(ID_COLUMN, sha1(col(FILE_PATH_COLUMN)))
}

private def getMetaInfo: String = {
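The build job above emits one row per source file and keys it on a sha1 of the file path. A minimal standalone sketch of that shape, assuming an existing `spark` session (as in spark-shell); the /tmp path, column names, and min/max aggregates are made up for illustration:

import org.apache.spark.sql.functions.{col, input_file_name, max, min, sha1}

// Write a small Parquet source so input_file_name() reports real file paths.
spark.range(0, 100).toDF("age").write.mode("overwrite").parquet("/tmp/flint_sketch_source")
val source = spark.read.parquet("/tmp/flint_sketch_source")

// One row per file: aggregate the skipping metadata, then derive a stable doc ID
// from the file path so rebuilding the index rewrites the same documents.
val skippingDf = source
  .groupBy(input_file_name().as("file_path"))
  .agg(min(col("age")).as("min_age"), max(col("age")).as("max_age"))
  .withColumn("__id__", sha1(col("file_path")))

skippingDf.show(truncate = false)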
FlintSparkSkippingIndexSuite.scala (new file)
@@ -0,0 +1,42 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.skipping

import org.mockito.Mockito.when
import org.opensearch.flint.spark.FlintSparkIndex.ID_COLUMN
import org.opensearch.flint.spark.skipping.FlintSparkSkippingIndex.FILE_PATH_COLUMN
import org.scalatest.matchers.must.Matchers.contain
import org.scalatest.matchers.should.Matchers.convertToAnyShouldWrapper
import org.scalatestplus.mockito.MockitoSugar.mock

import org.apache.spark.FlintSuite
import org.apache.spark.sql.catalyst.expressions.aggregate.CollectSet
import org.apache.spark.sql.functions.col

class FlintSparkSkippingIndexSuite extends FlintSuite {

test("get skipping index name") {
val index = new FlintSparkSkippingIndex("test", Seq(mock[FlintSparkSkippingStrategy]))
index.name() shouldBe "flint_test_skipping_index"
}

test("can build index building job with unique ID column") {
val indexCol = mock[FlintSparkSkippingStrategy]
when(indexCol.outputSchema()).thenReturn(Map("name" -> "string"))
when(indexCol.getAggregators).thenReturn(Seq(CollectSet(col("name").expr)))
val index = new FlintSparkSkippingIndex("test", Seq(indexCol))

val df = spark.createDataFrame(Seq(("hello", 20))).toDF("name", "age")
val indexDf = index.build(df)
indexDf.schema.fieldNames should contain only ("name", FILE_PATH_COLUMN, ID_COLUMN)
}

test("should fail if no indexed column given") {
assertThrows[IllegalArgumentException] {
new FlintSparkSkippingIndex("test", Seq.empty)
}
}
}
FlintSparkSkippingIndexITSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.flint.config.FlintSparkConf._
import org.apache.spark.sql.functions.col
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkSkippingIndexSuite
class FlintSparkSkippingIndexITSuite
extends QueryTest
with FlintSuite
with OpenSearchSuite
@@ -205,12 +205,14 @@ class FlintSparkSkippingIndexSuite
flint
.skippingIndex()
.onTable(testTable)
.addPartitions("year")
.create()

assertThrows[IllegalStateException] {
flint
.skippingIndex()
.onTable(testTable)
.addPartitions("year")
.create()
}
}
