Skip to content

Commit

Permalink
Misc fix
Browse files: browse the repository at this point in the history
  • Loading branch information
anishshri-db authored and jingz-db committed Jul 22, 2024
1 parent 3752517 commit fa73555
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ import org.apache.spark.sql.{RuntimeConfig, SparkSession}
import org.apache.spark.sql.catalyst.DataSourceOptions
import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues, STATE_VAR_NAME}
import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues
import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
import org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader, StateMetadataTableEntry}
import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata}
import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE}
import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
Expand All @@ -54,6 +54,30 @@ class StateDataSource extends TableProvider with DataSourceRegister {

private lazy val TRANSFORM_WITH_STATE_OPERATOR_SHORT_NAME = "transformWithStateExec"

/**
 * Validates that the state variable name option may be used against the operator described
 * by the given metadata. Only called when the caller has a state variable name to check.
 *
 * Two conditions are rejected with an invalid-option-value error for `STATE_VAR_NAME`:
 *  - the operator is not the transformWithState operator
 *  - the operator is transformWithState but its operator properties are empty
 *
 * @param sourceOptions options passed to the state data source; not read by the current checks
 * @param stateStoreMetadata operator metadata entries; exactly one entry is expected
 * @throws IllegalArgumentException if `stateStoreMetadata` does not hold exactly one entry
 */
private def runStateVarChecks(
    sourceOptions: StateSourceOptions,
    stateStoreMetadata: Array[StateMetadataTableEntry]): Unit = {
  // Fail with a descriptive message rather than a bare "requirement failed" so that an
  // unexpected metadata shape can be diagnosed from the exception alone.
  require(stateStoreMetadata.length == 1,
    "Expected exactly one operator metadata entry when a state variable name is provided, " +
      s"but found ${stateStoreMetadata.length}")
  val opMetadata = stateStoreMetadata.head

  // if we are trying to query state source with state variable name, then the operator
  // should be transformWithState
  if (opMetadata.operatorName != TRANSFORM_WITH_STATE_OPERATOR_SHORT_NAME) {
    val errorMsg = "Providing state variable names is only supported with the " +
      s"transformWithState operator. Found operator=${opMetadata.operatorName}. " +
      "Please remove this option and re-run the query."
    throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, errorMsg)
  }

  // if the operator is transformWithState, but the operator properties are empty, then
  // the user has not defined any state variables for the operator
  if (opMetadata.operatorProperties.isEmpty) {
    throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME,
      "No state variable names are defined for the transformWithState operator")
  }
}

override def getTable(
schema: StructType,
partitioning: Array[Transform],
Expand All @@ -71,15 +95,7 @@ class StateDataSource extends TableProvider with DataSourceRegister {
}

if (sourceOptions.stateVarName.isDefined) {
// Perform checks for transformWithState operator in case state variable name is provided
allStateStoreMetadata.foreach { entry =>
if (entry.operatorName != TRANSFORM_WITH_STATE_OPERATOR_SHORT_NAME) {
val errorMsg = "Providing state variable names is only supported with the " +
s"transformWithState operator. Found operator=${entry.operatorName}"
throw StateDataSourceErrors.invalidOptionValue(sourceOptions.stateVarName.get,
errorMsg)
}
}
runStateVarChecks(sourceOptions, stateStoreMetadata)
}

new StateTable(session, schema, sourceOptions, stateConf, stateStoreMetadata)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,25 @@ class StateDataSourceNegativeTestSuite extends StateDataSourceTestBase {
"message" -> s"value should be less than or equal to $endBatchId"))
}
}

// Negative test: supplying a state variable name against a checkpoint written by the
// runDropDuplicatesQuery helper must be rejected with an invalid-option-value error.
test("ERROR: trying to specify state variable name with " +
  "non-transformWithState operator") {
  withTempDir { checkpointDir =>
    val checkpointPath = checkpointDir.getAbsolutePath
    runDropDuplicatesQuery(checkpointPath)

    // The message parameter is given as the pattern ".*" together with matchPVals = true,
    // so only the option name is pinned here.
    val expectedParams = Map(
      "optionName" -> StateSourceOptions.STATE_VAR_NAME,
      "message" -> ".*")

    val thrown = intercept[StateDataSourceInvalidOptionValue] {
      val reader = spark.read.format("statestore")
        // trick to bypass getting the last committed batch before validating operator ID
        .option(StateSourceOptions.BATCH_ID, 0)
        .option(StateSourceOptions.STATE_VAR_NAME, "test")
      reader.load(checkpointPath)
    }

    checkError(thrown, "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", Some("42616"),
      expectedParams,
      matchPVals = true)
  }
}
}

/**
Expand Down

0 comments on commit fa73555

Please sign in to comment.