From fa73555ecc0d54142f4ead36c8e2df9652297d88 Mon Sep 17 00:00:00 2001
From: Anish Shrigondekar
Date: Fri, 5 Jul 2024 14:16:29 -0700
Subject: [PATCH] Misc fix

---
 .../v2/state/StateDataSource.scala          | 38 +++++++++++++------
 .../v2/state/StateDataSourceReadSuite.scala | 19 ++++++++++
 2 files changed, 46 insertions(+), 11 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
index f1c186e324e02..e096ff73b7b0f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSource.scala
@@ -28,9 +28,9 @@ import org.apache.spark.sql.{RuntimeConfig, SparkSession}
 import org.apache.spark.sql.catalyst.DataSourceOptions
 import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
 import org.apache.spark.sql.connector.expressions.Transform
-import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues
+import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.{JoinSideValues, STATE_VAR_NAME}
 import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions.JoinSideValues.JoinSideValues
-import org.apache.spark.sql.execution.datasources.v2.state.metadata.StateMetadataPartitionReader
+import org.apache.spark.sql.execution.datasources.v2.state.metadata.{StateMetadataPartitionReader, StateMetadataTableEntry}
 import org.apache.spark.sql.execution.streaming.{CommitLog, OffsetSeqLog, OffsetSeqMetadata}
 import org.apache.spark.sql.execution.streaming.StreamingCheckpointConstants.{DIR_NAME_COMMITS, DIR_NAME_OFFSETS, DIR_NAME_STATE}
 import org.apache.spark.sql.execution.streaming.StreamingSymmetricHashJoinHelper.{LeftSide, RightSide}
@@ -54,6 +54,30 @@ class StateDataSource extends TableProvider with DataSourceRegister {
 
   private lazy val TRANSFORM_WITH_STATE_OPERATOR_SHORT_NAME = "transformWithStateExec"
 
+  private def runStateVarChecks(
+      sourceOptions: StateSourceOptions,
+      stateStoreMetadata: Array[StateMetadataTableEntry]): Unit = {
+    // Perform checks for transformWithState operator in case state variable name is provided
+    require(stateStoreMetadata.size == 1)
+    val opMetadata = stateStoreMetadata.head
+    // if we are trying to query state source with state variable name, then the operator
+    // should be transformWithState
+    if (opMetadata.operatorName != TRANSFORM_WITH_STATE_OPERATOR_SHORT_NAME) {
+      val errorMsg = "Providing state variable names is only supported with the " +
+        s"transformWithState operator. Found operator=${opMetadata.operatorName}. " +
+        s"Please remove this option and re-run the query."
+      throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME, errorMsg)
+    }
+
+    // if the operator is transformWithState, but the operator properties are empty, then
+    // the user has not defined any state variables for the operator
+    val operatorProperties = opMetadata.operatorProperties
+    if (operatorProperties.isEmpty) {
+      throw StateDataSourceErrors.invalidOptionValue(STATE_VAR_NAME,
+        "No state variable names are defined for the transformWithState operator")
+    }
+  }
+
   override def getTable(
       schema: StructType,
       partitioning: Array[Transform],
@@ -71,15 +95,7 @@ class StateDataSource extends TableProvider with DataSourceRegister {
     }
 
     if (sourceOptions.stateVarName.isDefined) {
-      // Perform checks for transformWithState operator in case state variable name is provided
-      allStateStoreMetadata.foreach { entry =>
-        if (entry.operatorName != TRANSFORM_WITH_STATE_OPERATOR_SHORT_NAME) {
-          val errorMsg = "Providing state variable names is only supported with the " +
-            s"transformWithState operator. Found operator=${entry.operatorName}"
-          throw StateDataSourceErrors.invalidOptionValue(sourceOptions.stateVarName.get,
-            errorMsg)
-        }
-      }
+      runStateVarChecks(sourceOptions, stateStoreMetadata)
     }
 
     new StateTable(session, schema, sourceOptions, stateConf, stateStoreMetadata)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
index e6cdd0dce9efa..f3050f6bc9248 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceReadSuite.scala
@@ -268,6 +268,25 @@ class StateDataSourceNegativeTestSuite extends StateDataSourceTestBase {
           "message" -> s"value should be less than or equal to $endBatchId"))
     }
   }
+
+  test("ERROR: trying to specify state variable name with " +
+    s"non-transformWithState operator") {
+    withTempDir { tempDir =>
+      runDropDuplicatesQuery(tempDir.getAbsolutePath)
+
+      val exc = intercept[StateDataSourceInvalidOptionValue] {
+        spark.read.format("statestore")
+          // trick to bypass getting the last committed batch before validating operator ID
+          .option(StateSourceOptions.BATCH_ID, 0)
+          .option(StateSourceOptions.STATE_VAR_NAME, "test")
+          .load(tempDir.getAbsolutePath)
+      }
+      checkError(exc, "STDS_INVALID_OPTION_VALUE.WITH_MESSAGE", Some("42616"),
+        Map("optionName" -> StateSourceOptions.STATE_VAR_NAME,
+          "message" -> ".*"),
+        matchPVals = true)
+    }
+  }
 }
 
 /**
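For reference, a minimal sketch of how the validation added above surfaces to a
caller of the state data source, assuming a checkpoint written by a
transformWithState query; the checkpoint path and state variable name below are
hypothetical:

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions

    val spark = SparkSession.builder().master("local[*]").getOrCreate()

    val stateDf = spark.read
      .format("statestore")
      .option(StateSourceOptions.BATCH_ID, 0)
      .option(StateSourceOptions.STATE_VAR_NAME, "countState") // hypothetical variable name
      .load("/tmp/tws-checkpoint")                             // hypothetical checkpoint path

    // If the operator at this checkpoint is not transformWithState, or if it defines
    // no state variables, runStateVarChecks rejects the read with error class
    // STDS_INVALID_OPTION_VALUE.WITH_MESSAGE instead of failing later in the scan.
    stateDf.show()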