From ae0e2ca98bee12f4e59d7b41c996c2359522e7ac Mon Sep 17 00:00:00 2001 From: Dongjoon Hyun Date: Wed, 13 Feb 2019 16:04:27 -0800 Subject: [PATCH 01/70] [SPARK-26865][SQL] DataSourceV2Strategy should push normalized filters ## What changes were proposed in this pull request? This PR aims to make `DataSourceV2Strategy` normalize filters like [FileSourceStrategy](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala#L150-L158) when it pushes them into `SupportsPushDownFilters.pushFilters`. ## How was this patch tested? Pass the Jenkins with the newly added test case. Closes #23770 from dongjoon-hyun/SPARK-26865. Authored-by: Dongjoon Hyun Signed-off-by: Dongjoon Hyun --- .../datasources/DataSourceStrategy.scala | 16 ++++++++++++++++ .../datasources/FileSourceStrategy.scala | 10 +--------- .../datasources/v2/DataSourceV2Strategy.scala | 7 +++++-- .../datasources/DataSourceStrategySuite.scala | 7 +++++++ 4 files changed, 29 insertions(+), 11 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index b5cf8c9515bfb..273cc3b19302d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -426,6 +426,22 @@ case class DataSourceStrategy(conf: SQLConf) extends Strategy with Logging with } object DataSourceStrategy { + /** + * The attribute name of predicate could be different than the one in schema in case of + * case insensitive, we should change them to match the one in schema, so we do not need to + * worry about case sensitivity anymore. + */ + protected[sql] def normalizeFilters( + filters: Seq[Expression], + attributes: Seq[AttributeReference]): Seq[Expression] = { + filters.filterNot(SubqueryExpression.hasSubquery).map { e => + e transform { + case a: AttributeReference => + a.withName(attributes.find(_.semanticEquals(a)).get.name) + } + } + } + /** * Tries to translate a Catalyst [[Expression]] into data source [[Filter]]. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala index 62ab5c80d47cf..970cbda6355e9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategy.scala @@ -147,15 +147,7 @@ object FileSourceStrategy extends Strategy with Logging { // - filters that need to be evaluated again after the scan val filterSet = ExpressionSet(filters) - // The attribute name of predicate could be different than the one in schema in case of - // case insensitive, we should change them to match the one in schema, so we do not need to - // worry about case sensitivity anymore. 
- val normalizedFilters = filters.filterNot(SubqueryExpression.hasSubquery).map { e => - e transform { - case a: AttributeReference => - a.withName(l.output.find(_.semanticEquals(a)).get.name) - } - } + val normalizedFilters = DataSourceStrategy.normalizeFilters(filters, l.output) val partitionColumns = l.resolve( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 40ac5cf402987..d6d17d6df7b1b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable import org.apache.spark.sql.{sources, AnalysisException, SaveMode, Strategy} -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} @@ -104,10 +104,13 @@ object DataSourceV2Strategy extends Strategy { override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => val scanBuilder = relation.newScanBuilder() + + val normalizedFilters = DataSourceStrategy.normalizeFilters(filters, relation.output) + // `pushedFilters` will be pushed down and evaluated in the underlying data sources. // `postScanFilters` need to be evaluated after the scan. // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. - val (pushedFilters, postScanFilters) = pushFilters(scanBuilder, filters) + val (pushedFilters, postScanFilters) = pushFilters(scanBuilder, normalizedFilters) val (scan, output) = pruneColumns(scanBuilder, relation, project ++ postScanFilters) logInfo( s""" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala index f20aded169e44..2f5d5551c5df0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategySuite.scala @@ -219,6 +219,13 @@ class DataSourceStrategySuite extends PlanTest with SharedSQLContext { IsNotNull(attrInt))), None) } + test("SPARK-26865 DataSourceV2Strategy should push normalized filters") { + val attrInt = 'cint.int + assertResult(Seq(IsNotNull(attrInt))) { + DataSourceStrategy.normalizeFilters(Seq(IsNotNull(attrInt.withName("CiNt"))), Seq(attrInt)) + } + } + /** * Translate the given Catalyst [[Expression]] into data source [[sources.Filter]] * then verify against the given [[sources.Filter]]. From dc263480120d5ca59e464d25a9592721211b6abe Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Mon, 18 Feb 2019 13:16:28 +0800 Subject: [PATCH 02/70] [SPARK-26666][SQL] Support DSv2 overwrite and dynamic partition overwrite. ## What changes were proposed in this pull request? 
This adds two logical plans that implement the ReplaceData operation from the [logical plans SPIP](https://docs.google.com/document/d/1gYm5Ji2Mge3QBdOliFV5gSPTKlX4q1DCBXIkiyMv62A/edit?ts=5a987801#heading=h.m45webtwxf2d). These two plans will be used to implement Spark's `INSERT OVERWRITE` behavior for v2. Specific changes: * Add `SupportsTruncate`, `SupportsOverwrite`, and `SupportsDynamicOverwrite` to DSv2 write API * Add `OverwriteByExpression` and `OverwritePartitionsDynamic` plans (logical and physical) * Add new plans to DSv2 write validation rule `ResolveOutputRelation` * Refactor `WriteToDataSourceV2Exec` into trait used by all DSv2 write exec nodes ## How was this patch tested? * The v2 analysis suite has been updated to validate the new overwrite plans * The analysis suite for `OverwriteByExpression` checks that the delete expression is resolved using the table's columns * Existing tests validate that overwrite exec plan works * Updated existing v2 test because schema is used to validate overwrite Closes #23606 from rdblue/SPARK-26666-add-overwrite. Authored-by: Ryan Blue Signed-off-by: Wenchen Fan --- .../sql/catalyst/analysis/Analyzer.scala | 27 ++- .../plans/logical/basicLogicalOperators.scala | 69 ++++++- .../apache/spark/sql/internal/SQLConf.scala | 2 +- .../analysis/DataSourceV2AnalysisSuite.scala | 191 +++++++++++++----- .../v2/reader/SupportsPushDownFilters.java | 3 + .../v2/writer/SupportsDynamicOverwrite.java | 37 ++++ .../sources/v2/writer/SupportsOverwrite.java | 45 +++++ .../sources/v2/writer/SupportsTruncate.java | 32 +++ .../apache/spark/sql/DataFrameWriter.scala | 54 +++-- .../datasources/DataSourceStrategy.scala | 6 + .../v2/DataSourceV2Implicits.scala | 49 +++++ .../datasources/v2/DataSourceV2Relation.scala | 24 +-- .../datasources/v2/DataSourceV2Strategy.scala | 35 ++-- .../v2/WriteToDataSourceV2Exec.scala | 135 ++++++++++++- .../apache/spark/sql/sources/filters.scala | 26 ++- .../sql/sources/v2/DataSourceV2Suite.scala | 8 +- 16 files changed, 613 insertions(+), 130 deletions(-) create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index a84bb7653c527..0e95c10065676 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -978,6 +978,11 @@ class Analyzer( case a @ Aggregate(groupingExprs, aggExprs, appendColumns: AppendColumns) => a.mapExpressions(resolveExpressionTopDown(_, appendColumns)) + case o: OverwriteByExpression if !o.outputResolved => + // do not resolve expression attributes until the query attributes are resolved against the + // table by ResolveOutputRelation. that rule will alias the attributes to the table's names. 
+ o + case q: LogicalPlan => logTrace(s"Attempting to resolve ${q.simpleString(SQLConf.get.maxToStringFields)}") q.mapExpressions(resolveExpressionTopDown(_, q)) @@ -2237,7 +2242,7 @@ class Analyzer( object ResolveOutputRelation extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperators { case append @ AppendData(table, query, isByName) - if table.resolved && query.resolved && !append.resolved => + if table.resolved && query.resolved && !append.outputResolved => val projection = resolveOutputColumns(table.name, table.output, query, isByName) if (projection != query) { @@ -2245,6 +2250,26 @@ class Analyzer( } else { append } + + case overwrite @ OverwriteByExpression(table, _, query, isByName) + if table.resolved && query.resolved && !overwrite.outputResolved => + val projection = resolveOutputColumns(table.name, table.output, query, isByName) + + if (projection != query) { + overwrite.copy(query = projection) + } else { + overwrite + } + + case overwrite @ OverwritePartitionsDynamic(table, query, isByName) + if table.resolved && query.resolved && !overwrite.outputResolved => + val projection = resolveOutputColumns(table.name, table.output, query, isByName) + + if (projection != query) { + overwrite.copy(query = projection) + } else { + overwrite + } } def resolveOutputColumns( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 639d68f4ecd76..f7f701cea51fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -365,16 +365,17 @@ case class Join( } /** - * Append data to an existing table. + * Base trait for DataSourceV2 write commands */ -case class AppendData( - table: NamedRelation, - query: LogicalPlan, - isByName: Boolean) extends LogicalPlan { +trait V2WriteCommand extends Command { + def table: NamedRelation + def query: LogicalPlan + override def children: Seq[LogicalPlan] = Seq(query) - override def output: Seq[Attribute] = Seq.empty - override lazy val resolved: Boolean = { + override lazy val resolved: Boolean = outputResolved + + def outputResolved: Boolean = { table.resolved && query.resolved && query.output.size == table.output.size && query.output.zip(table.output).forall { case (inAttr, outAttr) => @@ -386,16 +387,66 @@ case class AppendData( } } +/** + * Append data to an existing table. + */ +case class AppendData( + table: NamedRelation, + query: LogicalPlan, + isByName: Boolean) extends V2WriteCommand + object AppendData { def byName(table: NamedRelation, df: LogicalPlan): AppendData = { - new AppendData(table, df, true) + new AppendData(table, df, isByName = true) } def byPosition(table: NamedRelation, query: LogicalPlan): AppendData = { - new AppendData(table, query, false) + new AppendData(table, query, isByName = false) } } +/** + * Overwrite data matching a filter in an existing table. 
+ */ +case class OverwriteByExpression( + table: NamedRelation, + deleteExpr: Expression, + query: LogicalPlan, + isByName: Boolean) extends V2WriteCommand { + override lazy val resolved: Boolean = outputResolved && deleteExpr.resolved +} + +object OverwriteByExpression { + def byName( + table: NamedRelation, df: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = { + OverwriteByExpression(table, deleteExpr, df, isByName = true) + } + + def byPosition( + table: NamedRelation, query: LogicalPlan, deleteExpr: Expression): OverwriteByExpression = { + OverwriteByExpression(table, deleteExpr, query, isByName = false) + } +} + +/** + * Dynamically overwrite partitions in an existing table. + */ +case class OverwritePartitionsDynamic( + table: NamedRelation, + query: LogicalPlan, + isByName: Boolean) extends V2WriteCommand + +object OverwritePartitionsDynamic { + def byName(table: NamedRelation, df: LogicalPlan): OverwritePartitionsDynamic = { + OverwritePartitionsDynamic(table, df, isByName = true) + } + + def byPosition(table: NamedRelation, query: LogicalPlan): OverwritePartitionsDynamic = { + OverwritePartitionsDynamic(table, query, isByName = false) + } +} + + /** * Insert some data into a table. Note that this plan is unresolved and has to be replaced by the * concrete implementations during analysis. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2e44ca5315c78..714dc6cda578d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1457,7 +1457,7 @@ object SQLConf { " register class names for which data source V2 write paths are disabled. 
Writes from these" + " sources will fall back to the V1 sources.") .stringConf - .createWithDefault("") + .createWithDefault("orc") val DISABLED_V2_STREAMING_WRITERS = buildConf("spark.sql.streaming.disabledV2Writers") .doc("A comma-separated list of fully qualified data source register class names for which" + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala index 6c899b610ac5b..0c48548614266 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala @@ -19,15 +19,92 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, UpCast} -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, Expression, LessThanOrEqual, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project} import org.apache.spark.sql.types.{DoubleType, FloatType, StructField, StructType} +class V2AppendDataAnalysisSuite extends DataSourceV2AnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + AppendData.byName(table, query) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + AppendData.byPosition(table, query) + } +} + +class V2OverwritePartitionsDynamicAnalysisSuite extends DataSourceV2AnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwritePartitionsDynamic.byName(table, query) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwritePartitionsDynamic.byPosition(table, query) + } +} + +class V2OverwriteByExpressionAnalysisSuite extends DataSourceV2AnalysisSuite { + override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwriteByExpression.byName(table, query, Literal(true)) + } + + override def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan = { + OverwriteByExpression.byPosition(table, query, Literal(true)) + } + + test("delete expression is resolved using table fields") { + val table = TestRelation(StructType(Seq( + StructField("x", DoubleType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("a", DoubleType, nullable = false), + StructField("b", DoubleType))).toAttributes) + + val a = query.output.head + val b = query.output.last + val x = table.output.head + + val parsedPlan = OverwriteByExpression.byPosition(table, query, + LessThanOrEqual(UnresolvedAttribute(Seq("x")), Literal(15.0d))) + + val expectedPlan = OverwriteByExpression.byPosition(table, + Project(Seq( + Alias(Cast(a, DoubleType, Some(conf.sessionLocalTimeZone)), "x")(), + Alias(Cast(b, DoubleType, Some(conf.sessionLocalTimeZone)), "y")()), + query), + LessThanOrEqual( + AttributeReference("x", DoubleType, nullable = false)(x.exprId), + Literal(15.0d))) + + assertNotResolved(parsedPlan) + checkAnalysis(parsedPlan, expectedPlan) + assertResolved(expectedPlan) + } + + test("delete expression is not resolved using query 
fields") { + val xRequiredTable = TestRelation(StructType(Seq( + StructField("x", DoubleType, nullable = false), + StructField("y", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("a", DoubleType, nullable = false), + StructField("b", DoubleType))).toAttributes) + + // the write is resolved (checked above). this test plan is not because of the expression. + val parsedPlan = OverwriteByExpression.byPosition(xRequiredTable, query, + LessThanOrEqual(UnresolvedAttribute(Seq("a")), Literal(15.0d))) + + assertNotResolved(parsedPlan) + assertAnalysisError(parsedPlan, Seq("cannot resolve", "`a`", "given input columns", "x, y")) + } +} + case class TestRelation(output: Seq[AttributeReference]) extends LeafNode with NamedRelation { override def name: String = "table-name" } -class DataSourceV2AnalysisSuite extends AnalysisTest { +abstract class DataSourceV2AnalysisSuite extends AnalysisTest { val table = TestRelation(StructType(Seq( StructField("x", FloatType), StructField("y", FloatType))).toAttributes) @@ -40,21 +117,25 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { StructField("x", DoubleType), StructField("y", DoubleType))).toAttributes) - test("Append.byName: basic behavior") { + def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan + + def byPosition(table: NamedRelation, query: LogicalPlan): LogicalPlan + + test("byName: basic behavior") { val query = TestRelation(table.schema.toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) checkAnalysis(parsedPlan, parsedPlan) assertResolved(parsedPlan) } - test("Append.byName: does not match by position") { + test("byName: does not match by position") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), StructField("b", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -62,12 +143,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'x'", "'y'")) } - test("Append.byName: case sensitive column resolution") { + test("byName: case sensitive column resolution") { val query = TestRelation(StructType(Seq( StructField("X", FloatType), // doesn't match case! StructField("y", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -76,7 +157,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { caseSensitive = true) } - test("Append.byName: case insensitive column resolution") { + test("byName: case insensitive column resolution") { val query = TestRelation(StructType(Seq( StructField("X", FloatType), // doesn't match case! 
StructField("y", FloatType))).toAttributes) @@ -84,8 +165,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val X = query.output.head val y = query.output.last - val parsedPlan = AppendData.byName(table, query) - val expectedPlan = AppendData.byName(table, + val parsedPlan = byName(table, query) + val expectedPlan = byName(table, Project(Seq( Alias(Cast(toLower(X), FloatType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -96,7 +177,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byName: data columns are reordered by name") { + test("byName: data columns are reordered by name") { // out of order val query = TestRelation(StructType(Seq( StructField("y", FloatType), @@ -105,8 +186,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val y = query.output.head val x = query.output.last - val parsedPlan = AppendData.byName(table, query) - val expectedPlan = AppendData.byName(table, + val parsedPlan = byName(table, query) + val expectedPlan = byName(table, Project(Seq( Alias(Cast(x, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -117,26 +198,26 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byName: fail nullable data written to required columns") { - val parsedPlan = AppendData.byName(requiredTable, table) + test("byName: fail nullable data written to required columns") { + val parsedPlan = byName(requiredTable, table) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write incompatible data to table", "'table-name'", "Cannot write nullable values to non-null column", "'x'", "'y'")) } - test("Append.byName: allow required data written to nullable columns") { - val parsedPlan = AppendData.byName(table, requiredTable) + test("byName: allow required data written to nullable columns") { + val parsedPlan = byName(table, requiredTable) assertResolved(parsedPlan) checkAnalysis(parsedPlan, parsedPlan) } - test("Append.byName: missing required columns cause failure and are identified by name") { + test("byName: missing required columns cause failure and are identified by name") { // missing required field x val query = TestRelation(StructType(Seq( StructField("y", FloatType, nullable = false))).toAttributes) - val parsedPlan = AppendData.byName(requiredTable, query) + val parsedPlan = byName(requiredTable, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -144,12 +225,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'x'")) } - test("Append.byName: missing optional columns cause failure and are identified by name") { + test("byName: missing optional columns cause failure and are identified by name") { // missing optional field x val query = TestRelation(StructType(Seq( StructField("y", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -157,8 +238,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'x'")) } - test("Append.byName: fail canWrite check") { - val parsedPlan = AppendData.byName(table, widerTable) + test("byName: fail canWrite check") { + val parsedPlan = byName(table, widerTable) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, 
Seq( @@ -166,12 +247,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) } - test("Append.byName: insert safe cast") { + test("byName: insert safe cast") { val x = table.output.head val y = table.output.last - val parsedPlan = AppendData.byName(widerTable, table) - val expectedPlan = AppendData.byName(widerTable, + val parsedPlan = byName(widerTable, table) + val expectedPlan = byName(widerTable, Project(Seq( Alias(Cast(x, DoubleType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(y, DoubleType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -182,13 +263,13 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byName: fail extra data fields") { + test("byName: fail extra data fields") { val query = TestRelation(StructType(Seq( StructField("x", FloatType), StructField("y", FloatType), StructField("z", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -197,7 +278,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'x', 'y', 'z'")) } - test("Append.byName: multiple field errors are reported") { + test("byName: multiple field errors are reported") { val xRequiredTable = TestRelation(StructType(Seq( StructField("x", FloatType, nullable = false), StructField("y", DoubleType))).toAttributes) @@ -206,7 +287,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { StructField("x", DoubleType), StructField("b", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(xRequiredTable, query) + val parsedPlan = byName(xRequiredTable, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -216,7 +297,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot find data for output column", "'y'")) } - test("Append.byPosition: basic behavior") { + test("byPosition: basic behavior") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), StructField("b", FloatType))).toAttributes) @@ -224,8 +305,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val a = query.output.head val b = query.output.last - val parsedPlan = AppendData.byPosition(table, query) - val expectedPlan = AppendData.byPosition(table, + val parsedPlan = byPosition(table, query) + val expectedPlan = byPosition(table, Project(Seq( Alias(Cast(a, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(b, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -236,7 +317,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byPosition: data columns are not reordered") { + test("byPosition: data columns are not reordered") { // out of order val query = TestRelation(StructType(Seq( StructField("y", FloatType), @@ -245,8 +326,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val y = query.output.head val x = query.output.last - val parsedPlan = AppendData.byPosition(table, query) - val expectedPlan = AppendData.byPosition(table, + val parsedPlan = byPosition(table, query) + val expectedPlan = byPosition(table, Project(Seq( Alias(Cast(y, FloatType, Some(conf.sessionLocalTimeZone)), "x")(), Alias(Cast(x, FloatType, Some(conf.sessionLocalTimeZone)), "y")()), @@ -257,26 +338,26 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byPosition: fail nullable data written to required 
columns") { - val parsedPlan = AppendData.byPosition(requiredTable, table) + test("byPosition: fail nullable data written to required columns") { + val parsedPlan = byPosition(requiredTable, table) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( "Cannot write incompatible data to table", "'table-name'", "Cannot write nullable values to non-null column", "'x'", "'y'")) } - test("Append.byPosition: allow required data written to nullable columns") { - val parsedPlan = AppendData.byPosition(table, requiredTable) + test("byPosition: allow required data written to nullable columns") { + val parsedPlan = byPosition(table, requiredTable) assertResolved(parsedPlan) checkAnalysis(parsedPlan, parsedPlan) } - test("Append.byPosition: missing required columns cause failure") { + test("byPosition: missing required columns cause failure") { // missing optional field x val query = TestRelation(StructType(Seq( StructField("y", FloatType, nullable = false))).toAttributes) - val parsedPlan = AppendData.byPosition(requiredTable, query) + val parsedPlan = byPosition(requiredTable, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -285,12 +366,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'y'")) } - test("Append.byPosition: missing optional columns cause failure") { + test("byPosition: missing optional columns cause failure") { // missing optional field x val query = TestRelation(StructType(Seq( StructField("y", FloatType))).toAttributes) - val parsedPlan = AppendData.byPosition(table, query) + val parsedPlan = byPosition(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -299,12 +380,12 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'y'")) } - test("Append.byPosition: fail canWrite check") { + test("byPosition: fail canWrite check") { val widerTable = TestRelation(StructType(Seq( StructField("a", DoubleType), StructField("b", DoubleType))).toAttributes) - val parsedPlan = AppendData.byPosition(table, widerTable) + val parsedPlan = byPosition(table, widerTable) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -312,7 +393,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot safely cast", "'x'", "'y'", "DoubleType to FloatType")) } - test("Append.byPosition: insert safe cast") { + test("byPosition: insert safe cast") { val widerTable = TestRelation(StructType(Seq( StructField("a", DoubleType), StructField("b", DoubleType))).toAttributes) @@ -320,8 +401,8 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { val x = table.output.head val y = table.output.last - val parsedPlan = AppendData.byPosition(widerTable, table) - val expectedPlan = AppendData.byPosition(widerTable, + val parsedPlan = byPosition(widerTable, table) + val expectedPlan = byPosition(widerTable, Project(Seq( Alias(Cast(x, DoubleType, Some(conf.sessionLocalTimeZone)), "a")(), Alias(Cast(y, DoubleType, Some(conf.sessionLocalTimeZone)), "b")()), @@ -332,13 +413,13 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { assertResolved(expectedPlan) } - test("Append.byPosition: fail extra data fields") { + test("byPosition: fail extra data fields") { val query = TestRelation(StructType(Seq( StructField("a", FloatType), StructField("b", FloatType), StructField("c", FloatType))).toAttributes) - val parsedPlan = AppendData.byName(table, query) + val parsedPlan = byName(table, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( @@ -347,7 +428,7 @@ class 
DataSourceV2AnalysisSuite extends AnalysisTest { "Data columns: 'a', 'b', 'c'")) } - test("Append.byPosition: multiple field errors are reported") { + test("byPosition: multiple field errors are reported") { val xRequiredTable = TestRelation(StructType(Seq( StructField("x", FloatType, nullable = false), StructField("y", DoubleType))).toAttributes) @@ -356,7 +437,7 @@ class DataSourceV2AnalysisSuite extends AnalysisTest { StructField("x", DoubleType), StructField("b", FloatType))).toAttributes) - val parsedPlan = AppendData.byPosition(xRequiredTable, query) + val parsedPlan = byPosition(xRequiredTable, query) assertNotResolved(parsedPlan) assertAnalysisError(parsedPlan, Seq( diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java index 296d3e47e732b..f10fd884daabe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java @@ -29,6 +29,9 @@ public interface SupportsPushDownFilters extends ScanBuilder { /** * Pushes down filters, and returns filters that need to be evaluated after scanning. + *
<p>
+ * Rows should be returned from the data source if and only if all of the filters match. That is, + * filters must be interpreted as ANDed together. */ Filter[] pushFilters(Filter[] filters); diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java new file mode 100644 index 0000000000000..8058964b662bd --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java @@ -0,0 +1,37 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +/** + * Write builder trait for tables that support dynamic partition overwrite. + *
<p>
+ * A write that dynamically overwrites partitions removes all existing data in each logical + * partition for which the write will commit new data. Any existing logical partition for which the + * write does not contain data will remain unchanged. + *
<p>
+ * This is provided to implement SQL compatible with Hive table operations but is not recommended. + * Instead, use the {@link SupportsOverwrite overwrite by filter API} to explicitly replace data. + */ +public interface SupportsDynamicOverwrite extends WriteBuilder { + /** + * Configures a write to dynamically replace partitions with data committed in the write. + * + * @return this write builder for method chaining + */ + WriteBuilder overwriteDynamicPartitions(); +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java new file mode 100644 index 0000000000000..b443b3c3aeb4a --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +import org.apache.spark.sql.sources.AlwaysTrue$; +import org.apache.spark.sql.sources.Filter; + +/** + * Write builder trait for tables that support overwrite by filter. + *
<p>
+ * Overwriting data by filter will delete any data that matches the filter and replace it with data + * that is committed in the write. + */ +public interface SupportsOverwrite extends WriteBuilder, SupportsTruncate { + /** + * Configures a write to replace data matching the filters with data committed in the write. + *
<p>
+ * Rows must be deleted from the data source if and only if all of the filters match. That is, + * filters must be interpreted as ANDed together. + * + * @param filters filters used to match data to overwrite + * @return this write builder for method chaining + */ + WriteBuilder overwrite(Filter[] filters); + + @Override + default WriteBuilder truncate() { + return overwrite(new Filter[] { AlwaysTrue$.MODULE$ }); + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java new file mode 100644 index 0000000000000..69c2ba5e01a49 --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2.writer; + +/** + * Write builder trait for tables that support truncation. + *
<p>
+ * Truncation removes all data in a table and replaces it with data that is committed in the write. + */ +public interface SupportsTruncate extends WriteBuilder { + /** + * Configures a write to replace all existing data with data committed in the write. + * + * @return this write builder for method chaining + */ + WriteBuilder truncate(); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 47fb548ecd43c..b5cfa85f6fb21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -25,7 +25,8 @@ import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{EliminateSubqueryAliases, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan} +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan, OverwriteByExpression} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} @@ -264,29 +265,38 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val dsOptions = new DataSourceOptions(options.asJava) provider.getTable(dsOptions) match { case table: SupportsBatchWrite => - if (mode == SaveMode.Append) { - val relation = DataSourceV2Relation.create(table, options) - runCommand(df.sparkSession, "save") { - AppendData.byName(relation, df.logicalPlan) - } - } else { - val writeBuilder = table.newWriteBuilder(dsOptions) - .withQueryId(UUID.randomUUID().toString) - .withInputDataSchema(df.logicalPlan.schema) - writeBuilder match { - case s: SupportsSaveMode => - val write = s.mode(mode).buildForBatch() - // It can only return null with `SupportsSaveMode`. We can clean it up after - // removing `SupportsSaveMode`. - if (write != null) { - runCommand(df.sparkSession, "save") { - WriteToDataSourceV2(write, df.logicalPlan) + lazy val relation = DataSourceV2Relation.create(table, options) + mode match { + case SaveMode.Append => + runCommand(df.sparkSession, "save") { + AppendData.byName(relation, df.logicalPlan) + } + + case SaveMode.Overwrite => + // truncate the table + runCommand(df.sparkSession, "save") { + OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true)) + } + + case _ => + table.newWriteBuilder(dsOptions) match { + case writeBuilder: SupportsSaveMode => + val write = writeBuilder.mode(mode) + .withQueryId(UUID.randomUUID().toString) + .withInputDataSchema(df.logicalPlan.schema) + .buildForBatch() + // It can only return null with `SupportsSaveMode`. We can clean it up after + // removing `SupportsSaveMode`. + if (write != null) { + runCommand(df.sparkSession, "save") { + WriteToDataSourceV2(write, df.logicalPlan) + } } - } - case _ => throw new AnalysisException( - s"data source ${table.name} does not support SaveMode $mode") - } + case _ => + throw new AnalysisException( + s"data source ${table.name} does not support SaveMode $mode") + } } // Streaming also uses the data source V2 API. 
So it may be that the data source implements diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala index 273cc3b19302d..b73dc30d6f23c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceStrategy.scala @@ -529,6 +529,12 @@ object DataSourceStrategy { case expressions.Contains(a: Attribute, Literal(v: UTF8String, StringType)) => Some(sources.StringContains(a.name, v.toString)) + case expressions.Literal(true, BooleanType) => + Some(sources.AlwaysTrue) + + case expressions.Literal(false, BooleanType) => + Some(sources.AlwaysFalse) + case _ => None } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala new file mode 100644 index 0000000000000..c8542bfe5e59b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsBatchRead, SupportsBatchWrite, Table} + +object DataSourceV2Implicits { + implicit class TableHelper(table: Table) { + def asBatchReadable: SupportsBatchRead = { + table match { + case support: SupportsBatchRead => + support + case _ => + throw new AnalysisException(s"Table does not support batch reads: ${table.name}") + } + } + + def asBatchWritable: SupportsBatchWrite = { + table match { + case support: SupportsBatchWrite => + support + case _ => + throw new AnalysisException(s"Table does not support batch writes: ${table.name}") + } + } + } + + implicit class OptionsHelper(options: Map[String, String]) { + def toDataSourceOptions: DataSourceOptions = new DataSourceOptions(options.asJava) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 47cf26dc9481e..53677782c95f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -17,11 +17,6 @@ package org.apache.spark.sql.execution.datasources.v2 -import java.util.UUID - -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} @@ -30,7 +25,6 @@ import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.types.StructType /** * A logical plan representing a data source v2 table. 
@@ -45,26 +39,16 @@ case class DataSourceV2Relation( options: Map[String, String]) extends LeafNode with MultiInstanceRelation with NamedRelation { + import DataSourceV2Implicits._ + override def name: String = table.name() override def simpleString(maxFields: Int): String = { s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $name" } - def newScanBuilder(): ScanBuilder = table match { - case s: SupportsBatchRead => - val dsOptions = new DataSourceOptions(options.asJava) - s.newScanBuilder(dsOptions) - case _ => throw new AnalysisException(s"Table is not readable: ${table.name()}") - } - - def newWriteBuilder(schema: StructType): WriteBuilder = table match { - case s: SupportsBatchWrite => - val dsOptions = new DataSourceOptions(options.asJava) - s.newWriteBuilder(dsOptions) - .withQueryId(UUID.randomUUID().toString) - .withInputDataSchema(schema) - case _ => throw new AnalysisException(s"Table is not writable: ${table.name()}") + def newScanBuilder(): ScanBuilder = { + table.asBatchReadable.newScanBuilder(options.toDataSourceOptions) } override def computeStats(): Statistics = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index d6d17d6df7b1b..55d7b0a18cbc8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -19,18 +19,18 @@ package org.apache.spark.sql.execution.datasources.v2 import scala.collection.mutable -import org.apache.spark.sql.{sources, AnalysisException, SaveMode, Strategy} -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, SubqueryExpression} +import org.apache.spark.sql.{AnalysisException, Strategy} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} +import org.apache.spark.sql.sources import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} -import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode -object DataSourceV2Strategy extends Strategy { +object DataSourceV2Strategy extends Strategy with PredicateHelper { /** * Pushes down filters to the data source reader @@ -100,6 +100,7 @@ object DataSourceV2Strategy extends Strategy { } } + import DataSourceV2Implicits._ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => @@ -146,14 +147,22 @@ object DataSourceV2Strategy extends Strategy { WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil case AppendData(r: DataSourceV2Relation, query, _) => - val writeBuilder = r.newWriteBuilder(query.schema) - writeBuilder match 
{ - case s: SupportsSaveMode => - val write = s.mode(SaveMode.Append).buildForBatch() - assert(write != null) - WriteToDataSourceV2Exec(write, planLater(query)) :: Nil - case _ => throw new AnalysisException(s"data source ${r.name} does not support SaveMode") - } + AppendDataExec( + r.table.asBatchWritable, r.options.toDataSourceOptions, planLater(query)) :: Nil + + case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, _) => + // fail if any filter cannot be converted. correctness depends on removing all matching data. + val filters = splitConjunctivePredicates(deleteExpr).map { + filter => DataSourceStrategy.translateFilter(deleteExpr).getOrElse( + throw new AnalysisException(s"Cannot translate expression to source filter: $filter")) + }.toArray + + OverwriteByExpressionExec( + r.table.asBatchWritable, filters, r.options.toDataSourceOptions, planLater(query)) :: Nil + + case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _) => + OverwritePartitionsDynamicExec(r.table.asBatchWritable, + r.options.toDataSourceOptions, planLater(query)) :: Nil case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 50c5e4f2ad7df..d7cb2457433b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -17,17 +17,22 @@ package org.apache.spark.sql.execution.datasources.v2 +import java.util.UUID + import scala.util.control.NonFatal import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD +import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} -import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.sources.{AlwaysTrue, Filter} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsBatchWrite} +import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsSaveMode, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.util.{LongAccumulator, Utils} /** @@ -42,17 +47,137 @@ case class WriteToDataSourceV2(batchWrite: BatchWrite, query: LogicalPlan) } /** - * The physical plan for writing data into data source v2. + * Physical plan node for append into a v2 table. + * + * Rows in the output data set are appended. + */ +case class AppendDataExec( + table: SupportsBatchWrite, + writeOptions: DataSourceOptions, + query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { + + override protected def doExecute(): RDD[InternalRow] = { + val batchWrite = newWriteBuilder() match { + case builder: SupportsSaveMode => + builder.mode(SaveMode.Append).buildForBatch() + + case builder => + builder.buildForBatch() + } + doWrite(batchWrite) + } +} + +/** + * Physical plan node for overwrite into a v2 table. + * + * Overwrites data in a table matched by a set of filters. 
Rows matching all of the filters will be + * deleted and rows in the output data set are appended. + * + * This plan is used to implement SaveMode.Overwrite. The behavior of SaveMode.Overwrite is to + * truncate the table -- delete all rows -- and append the output data set. This uses the filter + * AlwaysTrue to delete all rows. */ -case class WriteToDataSourceV2Exec(batchWrite: BatchWrite, query: SparkPlan) - extends UnaryExecNode { +case class OverwriteByExpressionExec( + table: SupportsBatchWrite, + deleteWhere: Array[Filter], + writeOptions: DataSourceOptions, + query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { + + private def isTruncate(filters: Array[Filter]): Boolean = { + filters.length == 1 && filters(0).isInstanceOf[AlwaysTrue] + } + + override protected def doExecute(): RDD[InternalRow] = { + val batchWrite = newWriteBuilder() match { + case builder: SupportsTruncate if isTruncate(deleteWhere) => + builder.truncate().buildForBatch() + + case builder: SupportsSaveMode if isTruncate(deleteWhere) => + builder.mode(SaveMode.Overwrite).buildForBatch() + + case builder: SupportsOverwrite => + builder.overwrite(deleteWhere).buildForBatch() + + case _ => + throw new SparkException(s"Table does not support dynamic partition overwrite: $table") + } + + doWrite(batchWrite) + } +} + +/** + * Physical plan node for dynamic partition overwrite into a v2 table. + * + * Dynamic partition overwrite is the behavior of Hive INSERT OVERWRITE ... PARTITION queries, and + * Spark INSERT OVERWRITE queries when spark.sql.sources.partitionOverwriteMode=dynamic. Each + * partition in the output data set replaces the corresponding existing partition in the table or + * creates a new partition. Existing partitions for which there is no data in the output data set + * are not modified. + */ +case class OverwritePartitionsDynamicExec( + table: SupportsBatchWrite, + writeOptions: DataSourceOptions, + query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { + + override protected def doExecute(): RDD[InternalRow] = { + val batchWrite = newWriteBuilder() match { + case builder: SupportsDynamicOverwrite => + builder.overwriteDynamicPartitions().buildForBatch() + + case builder: SupportsSaveMode => + builder.mode(SaveMode.Overwrite).buildForBatch() + + case _ => + throw new SparkException(s"Table does not support dynamic partition overwrite: $table") + } + + doWrite(batchWrite) + } +} + +case class WriteToDataSourceV2Exec( + batchWrite: BatchWrite, + query: SparkPlan + ) extends V2TableWriteExec { + + import DataSourceV2Implicits._ + + def writeOptions: DataSourceOptions = Map.empty[String, String].toDataSourceOptions + + override protected def doExecute(): RDD[InternalRow] = { + doWrite(batchWrite) + } +} + +/** + * Helper for physical plans that build batch writes. + */ +trait BatchWriteHelper { + def table: SupportsBatchWrite + def query: SparkPlan + def writeOptions: DataSourceOptions + + def newWriteBuilder(): WriteBuilder = { + table.newWriteBuilder(writeOptions) + .withInputDataSchema(query.schema) + .withQueryId(UUID.randomUUID().toString) + } +} + +/** + * The base physical plan for writing data into data source v2. 
+ */ +trait V2TableWriteExec extends UnaryExecNode { + def query: SparkPlan var commitProgress: Option[StreamWriterCommitProgress] = None override def child: SparkPlan = query override def output: Seq[Attribute] = Nil - override protected def doExecute(): RDD[InternalRow] = { + protected def doWrite(batchWrite: BatchWrite): RDD[InternalRow] = { val writerFactory = batchWrite.createBatchWriterFactory() val useCommitCoordinator = batchWrite.useCommitCoordinator val rdd = query.execute() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala index 3f941cc6e1072..a1ab55a7185ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.sources -import org.apache.spark.annotation.Stable +import org.apache.spark.annotation.{Evolving, Stable} //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines all the filters that we can push down to the data sources. @@ -218,3 +218,27 @@ case class StringEndsWith(attribute: String, value: String) extends Filter { case class StringContains(attribute: String, value: String) extends Filter { override def references: Array[String] = Array(attribute) } + +/** + * A filter that always evaluates to `true`. + */ +@Evolving +case class AlwaysTrue() extends Filter { + override def references: Array[String] = Array.empty +} + +@Evolving +object AlwaysTrue extends AlwaysTrue { +} + +/** + * A filter that always evaluates to `false`. + */ +@Evolving +case class AlwaysFalse() extends Filter { + override def references: Array[String] = Array.empty +} + +@Evolving +object AlwaysFalse extends AlwaysFalse { +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 511fdfe5c23ac..6b5c45e40ab0c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -351,19 +351,21 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } - test("SPARK-25700: do not read schema when writing in other modes except append mode") { + test("SPARK-25700: do not read schema when writing in other modes except append and overwrite") { withTempPath { file => val cls = classOf[SimpleWriteOnlyDataSource] val path = file.getCanonicalPath val df = spark.range(5).select('id as 'i, -'id as 'j) // non-append mode should not throw exception, as they don't access schema. df.write.format(cls.getName).option("path", path).mode("error").save() - df.write.format(cls.getName).option("path", path).mode("overwrite").save() df.write.format(cls.getName).option("path", path).mode("ignore").save() - // append mode will access schema and should throw exception. + // append and overwrite modes will access the schema and should throw exception. 
intercept[SchemaReadAttemptException] { df.write.format(cls.getName).option("path", path).mode("append").save() } + intercept[SchemaReadAttemptException] { + df.write.format(cls.getName).option("path", path).mode("overwrite").save() + } } } } From af5d1875fb549063c59ee5c58c6cf7ecfecf9e3b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 18 Feb 2019 16:17:24 -0800 Subject: [PATCH 03/70] [SPARK-26785][SQL] data source v2 API refactor: streaming write ## What changes were proposed in this pull request? Continue the API refactor for streaming write, according to the [doc](https://docs.google.com/document/d/1vI26UEuDpVuOjWw4WPoH2T6y8WAekwtI7qoowhOFnI4/edit?usp=sharing). The major changes: 1. rename `StreamingWriteSupport` to `StreamingWrite` 2. add `WriteBuilder.buildForStreaming` 3. update existing sinks, to move the creation of `StreamingWrite` to `Table` ## How was this patch tested? existing tests Closes #23702 from cloud-fan/stream-write. Authored-by: Wenchen Fan Signed-off-by: gatorsmile --- .../sql/kafka010/KafkaSourceProvider.scala | 42 +++++---- ...upport.scala => KafkaStreamingWrite.scala} | 8 +- .../sql/sources/v2/SessionConfigSupport.java | 4 +- .../v2/StreamingWriteSupportProvider.java | 54 ------------ .../sql/sources/v2/SupportsBatchWrite.java | 2 +- .../sources/v2/SupportsStreamingWrite.java | 33 +++++++ .../spark/sql/sources/v2/TableProvider.java | 3 +- .../sql/sources/v2/writer/WriteBuilder.java | 9 +- .../v2/writer/WriterCommitMessage.java | 4 +- .../streaming/StreamingDataWriterFactory.java | 2 +- ...gWriteSupport.java => StreamingWrite.java} | 21 ++++- .../streaming/SupportsOutputMode.java} | 17 ++-- .../apache/spark/sql/DataFrameReader.scala | 2 +- .../datasources/noop/NoopDataSource.scala | 26 ++---- .../v2/DataSourceV2StringFormat.scala | 88 ------------------- .../datasources/v2/DataSourceV2Utils.scala | 43 ++++----- .../streaming/MicroBatchExecution.scala | 20 +++-- .../streaming/StreamingRelation.scala | 6 +- .../sql/execution/streaming/console.scala | 43 ++++++--- .../continuous/ContinuousExecution.scala | 25 +++--- .../continuous/EpochCoordinator.scala | 6 +- .../WriteToContinuousDataSource.scala | 6 +- .../WriteToContinuousDataSourceExec.scala | 13 +-- ...eWriteSupport.scala => ConsoleWrite.scala} | 6 +- ...rovider.scala => ForeachWriterTable.scala} | 76 +++++++++------- .../streaming/sources/MicroBatchWrite.scala | 4 +- .../sources/RateStreamProvider.scala | 3 +- .../sources/TextSocketSourceProvider.scala | 3 +- .../streaming/sources/memoryV2.scala | 42 ++++++--- .../sql/streaming/DataStreamReader.scala | 2 +- .../sql/streaming/DataStreamWriter.scala | 50 ++++++----- .../sql/streaming/StreamingQueryManager.scala | 4 +- ...pache.spark.sql.sources.DataSourceRegister | 2 +- .../streaming/MemorySinkV2Suite.scala | 6 +- .../sources/v2/DataSourceV2UtilsSuite.scala | 4 +- .../sources/v2/SimpleWritableDataSource.scala | 3 +- .../ContinuousQueuedDataReaderSuite.scala | 4 +- .../continuous/EpochCoordinatorSuite.scala | 6 +- .../sources/StreamingDataSourceV2Suite.scala | 70 +++++++++------ 39 files changed, 373 insertions(+), 389 deletions(-) rename external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/{KafkaStreamingWriteSupport.scala => KafkaStreamingWrite.scala} (95%) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java rename 
sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/{StreamingWriteSupport.java => StreamingWrite.java} (73%) rename sql/core/src/main/java/org/apache/spark/sql/sources/v2/{DataSourceV2.java => writer/streaming/SupportsOutputMode.java} (67%) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/{ConsoleWriteSupport.scala => ConsoleWrite.scala} (94%) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/{ForeachWriteSupportProvider.scala => ForeachWriterTable.scala} (66%) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 9238899b0c00c..6994517b27d6a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -33,7 +33,8 @@ import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.WriteBuilder +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingWrite, SupportsOutputMode} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -47,7 +48,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSinkProvider with RelationProvider with CreatableRelationProvider - with StreamingWriteSupportProvider with TableProvider with Logging { import KafkaSourceProvider._ @@ -180,20 +180,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } } - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - import scala.collection.JavaConverters._ - - val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) - // We convert the options argument from V2 -> Java map -> scala mutable -> scala immutable. 
- val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) - - new KafkaStreamingWriteSupport(topic, producerParams, schema) - } - private def strategy(caseInsensitiveParams: Map[String, String]) = caseInsensitiveParams.find(x => STRATEGY_OPTION_KEYS.contains(x._1)).get match { case ("assign", value) => @@ -365,7 +351,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } class KafkaTable(strategy: => ConsumerStrategy) extends Table - with SupportsMicroBatchRead with SupportsContinuousRead { + with SupportsMicroBatchRead with SupportsContinuousRead with SupportsStreamingWrite { override def name(): String = s"Kafka $strategy" @@ -374,6 +360,28 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new ScanBuilder { override def build(): Scan = new KafkaScan(options) } + + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + new WriteBuilder with SupportsOutputMode { + private var inputSchema: StructType = _ + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.inputSchema = schema + this + } + + override def outputMode(mode: OutputMode): WriteBuilder = this + + override def buildForStreaming(): StreamingWrite = { + import scala.collection.JavaConverters._ + + assert(inputSchema != null) + val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) + val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) + new KafkaStreamingWrite(topic, producerParams, inputSchema) + } + } + } } class KafkaScan(options: DataSourceOptions) extends Scan { diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala similarity index 95% rename from external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala rename to external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala index 0d831c3884609..e3101e1572082 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWriteSupport.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaStreamingWrite.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.kafka010.KafkaWriter.validateQuery import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType /** @@ -33,18 +33,18 @@ import org.apache.spark.sql.types.StructType case object KafkaWriterCommitMessage extends WriterCommitMessage /** - * A [[StreamingWriteSupport]] for Kafka writing. Responsible for generating the writer factory. + * A [[StreamingWrite]] for Kafka writing. Responsible for generating the writer factory. * * @param topic The topic this writer is responsible for. If None, topic will be inferred from * a `topic` field in the incoming data. * @param producerParams Parameters for Kafka producers in each task. * @param schema The schema of the input data. 
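For context, the new Kafka write builder is exercised through the usual streaming write API; a minimal usage sketch follows (bootstrap servers, topic and checkpoint path are placeholders).

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("kafka-sink-example").getOrCreate()

// Any streaming source works; the rate source is used here only for illustration.
val df = spark.readStream.format("rate").load()
  .selectExpr("CAST(value AS STRING) AS value")

val query = df.writeStream
  .format("kafka")                                   // resolved to KafkaSourceProvider
  .option("kafka.bootstrap.servers", "host1:9092")   // placeholder
  .option("topic", "events")                         // placeholder
  .option("checkpointLocation", "/tmp/chk/events")   // placeholder
  .start()
```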
*/ -class KafkaStreamingWriteSupport( +class KafkaStreamingWrite( topic: Option[String], producerParams: ju.Map[String, Object], schema: StructType) - extends StreamingWriteSupport { + extends StreamingWrite { validateQuery(schema.toAttributes, producerParams, topic) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java index c00abd9b685b5..d27fbfdd14617 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java @@ -20,12 +20,12 @@ import org.apache.spark.annotation.Evolving; /** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to + * A mix-in interface for {@link TableProvider}. Data sources can implement this interface to * propagate session configs with the specified key-prefix to all data source operations in this * session. */ @Evolving -public interface SessionConfigSupport extends DataSourceV2 { +public interface SessionConfigSupport extends TableProvider { /** * Key prefix of the session configs to propagate, which is usually the data source name. Spark diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java deleted file mode 100644 index 8ac9c51750865..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/StreamingWriteSupportProvider.java +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; -import org.apache.spark.sql.streaming.OutputMode; -import org.apache.spark.sql.types.StructType; - -/** - * A mix-in interface for {@link DataSourceV2}. Data sources can implement this interface to - * provide data writing ability for structured streaming. - * - * This interface is used to create {@link StreamingWriteSupport} instances when end users run - * {@code Dataset.writeStream.format(...).option(...).start()}. - */ -@Evolving -public interface StreamingWriteSupportProvider extends DataSourceV2, BaseStreamingSink { - - /** - * Creates a {@link StreamingWriteSupport} instance to save the data to this data source, which is - * called by Spark at the beginning of each streaming query. - * - * @param queryId A unique string for the writing query. 
It's possible that there are many - * writing queries running at the same time, and the returned - * {@link StreamingWriteSupport} can use this id to distinguish itself from others. - * @param schema the schema of the data to be written. - * @param mode the output mode which determines what successive epoch output means to this - * sink, please refer to {@link OutputMode} for more details. - * @param options the options for the returned data source writer, which is an immutable - * case-insensitive string-to-string map. - */ - StreamingWriteSupport createStreamingWriteSupport( - String queryId, - StructType schema, - OutputMode mode, - DataSourceOptions options); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java index 08caadd5308e6..b2cd97a2f5332 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java @@ -24,7 +24,7 @@ * An empty mix-in interface for {@link Table}, to indicate this table supports batch write. *

* If a {@link Table} implements this interface, the - * {@link SupportsWrite#newWriteBuilder(DataSourceOptions)} must return a {@link WriteBuilder} + * {@link SupportsWrite#newWriteBuilder(DataSourceOptions)} must return a {@link WriteBuilder} * with {@link WriteBuilder#buildForBatch()} implemented. *

*/ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java new file mode 100644 index 0000000000000..1050d35250c1f --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.execution.streaming.BaseStreamingSink; +import org.apache.spark.sql.sources.v2.writer.WriteBuilder; + +/** + * An empty mix-in interface for {@link Table}, to indicate this table supports streaming write. + *

+ * If a {@link Table} implements this interface, the + * {@link SupportsWrite#newWriteBuilder(DataSourceOptions)} must return a {@link WriteBuilder} + * with {@link WriteBuilder#buildForStreaming()} implemented. + *

+ */ +@Evolving +public interface SupportsStreamingWrite extends SupportsWrite, BaseStreamingSink { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java index 855d5efe0c69f..a9b83b6de9950 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java @@ -29,8 +29,7 @@ *

*/ @Evolving -// TODO: do not extend `DataSourceV2`, after we finish the API refactor completely. -public interface TableProvider extends DataSourceV2 { +public interface TableProvider { /** * Return a {@link Table} instance to do read/write with user-specified options. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java index e861c72af9e68..07529fe1dee91 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java @@ -20,6 +20,7 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.SupportsBatchWrite; import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite; import org.apache.spark.sql.types.StructType; /** @@ -64,6 +65,12 @@ default WriteBuilder withInputDataSchema(StructType schema) { * {@link SupportsSaveMode}. */ default BatchWrite buildForBatch() { - throw new UnsupportedOperationException("Batch scans are not supported"); + throw new UnsupportedOperationException(getClass().getName() + + " does not support batch write"); + } + + default StreamingWrite buildForStreaming() { + throw new UnsupportedOperationException(getClass().getName() + + " does not support streaming write"); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java index 6334c8f643098..23e8580c404d4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java @@ -20,12 +20,12 @@ import java.io.Serializable; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport; +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite; /** * A commit message returned by {@link DataWriter#commit()} and will be sent back to the driver side * as the input parameter of {@link BatchWrite#commit(WriterCommitMessage[])} or - * {@link StreamingWriteSupport#commit(long, WriterCommitMessage[])}. + * {@link StreamingWrite#commit(long, WriterCommitMessage[])}. * * This is an empty interface, data sources should define their own message class and use it when * generating messages at executor side and handling the messages at driver side. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java index 7d3d21cb2b637..af2f03c9d4192 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java @@ -26,7 +26,7 @@ /** * A factory of {@link DataWriter} returned by - * {@link StreamingWriteSupport#createStreamingWriterFactory()}, which is responsible for creating + * {@link StreamingWrite#createStreamingWriterFactory()}, which is responsible for creating * and initializing the actual data writer at executor side. 
* * Note that, the writer factory will be serialized and sent to executors, then the data writer diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java similarity index 73% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java index 84cfbf2dda483..5617f1cdc0efc 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWriteSupport.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java @@ -22,13 +22,26 @@ import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage; /** - * An interface that defines how to write the data to data source for streaming processing. + * An interface that defines how to write the data to data source in streaming queries. * - * Streaming queries are divided into intervals of data called epochs, with a monotonically - * increasing numeric ID. This writer handles commits and aborts for each successive epoch. + * The writing procedure is: + * 1. Create a writer factory by {@link #createStreamingWriterFactory()}, serialize and send it to + * all the partitions of the input data(RDD). + * 2. For each epoch in each partition, create the data writer, and write the data of the epoch in + * the partition with this writer. If all the data are written successfully, call + * {@link DataWriter#commit()}. If exception happens during the writing, call + * {@link DataWriter#abort()}. + * 3. If writers in all partitions of one epoch are successfully committed, call + * {@link #commit(long, WriterCommitMessage[])}. If some writers are aborted, or the job failed + * with an unknown reason, call {@link #abort(long, WriterCommitMessage[])}. + * + * While Spark will retry failed writing tasks, Spark won't retry failed writing jobs. Users should + * do it manually in their Spark applications if they want to retry. + * + * Please refer to the documentation of commit/abort methods for detailed specifications. */ @Evolving -public interface StreamingWriteSupport { +public interface StreamingWrite { /** * Creates a writer factory which will be serialized and sent to executors. diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsOutputMode.java similarity index 67% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java rename to sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsOutputMode.java index 43bdcca70cb09..832dcfa145d1b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceV2.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsOutputMode.java @@ -15,12 +15,15 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2; +package org.apache.spark.sql.sources.v2.writer.streaming; -import org.apache.spark.annotation.Evolving; +import org.apache.spark.annotation.Unstable; +import org.apache.spark.sql.sources.v2.writer.WriteBuilder; +import org.apache.spark.sql.streaming.OutputMode; -/** - * TODO: remove it when we finish the API refactor for streaming write side. 
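To make the renamed interface concrete, here is a minimal, do-nothing `StreamingWrite` sketch. The object name is made up, and the factory method signature is assumed to match the streaming writer factory used elsewhere in this patch series.

```scala
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage}
import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite}

// Step 1: the factory is serialized to executors; steps 2-3: per-epoch commit/abort.
object LoggingStreamingWrite extends StreamingWrite {
  override def createStreamingWriterFactory(): StreamingDataWriterFactory =
    new StreamingDataWriterFactory {
      override def createWriter(
          partitionId: Int, taskId: Long, epochId: Long): DataWriter[InternalRow] =
        new DataWriter[InternalRow] {
          override def write(record: InternalRow): Unit = ()            // discard rows
          override def commit(): WriterCommitMessage = new WriterCommitMessage {}
          override def abort(): Unit = ()
        }
    }

  override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
    println(s"epoch $epochId committed by ${messages.length} partitions")

  override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit =
    println(s"epoch $epochId aborted")
}
```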
- */ -@Evolving -public interface DataSourceV2 {} +// TODO: remove it when we have `SupportsTruncate` +@Unstable +public interface SupportsOutputMode extends WriteBuilder { + + WriteBuilder outputMode(OutputMode mode); +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index a380a06cb942b..ffa19895ee3c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -205,7 +205,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { if (classOf[TableProvider].isAssignableFrom(cls)) { val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - ds = provider, conf = sparkSession.sessionState.conf) + source = provider, conf = sparkSession.sessionState.conf) val pathsOption = { val objectMapper = new ObjectMapper() DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 452ebbbeb99c8..8f2072c586a94 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite, SupportsOutputMode} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -30,30 +30,23 @@ import org.apache.spark.sql.types.StructType * This is no-op datasource. It does not do anything besides consuming its input. * This can be useful for benchmarking or to cache data without any additional overhead. 
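With this change the noop sink is reachable from both the batch and the streaming write paths; a usage sketch (the checkpoint path is a placeholder):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().appName("noop-example").getOrCreate()

// Batch: rows are consumed and discarded, which isolates the cost of the read side.
spark.range(0, 1000000).write.format("noop").mode("overwrite").save()

// Streaming: the same table now also builds a StreamingWrite via buildForStreaming().
val query = spark.readStream.format("rate").load()
  .writeStream
  .format("noop")
  .option("checkpointLocation", "/tmp/chk/noop")  // placeholder
  .start()
```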
*/ -class NoopDataSource - extends DataSourceV2 - with TableProvider - with DataSourceRegister - with StreamingWriteSupportProvider { - +class NoopDataSource extends TableProvider with DataSourceRegister { override def shortName(): String = "noop" override def getTable(options: DataSourceOptions): Table = NoopTable - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = NoopStreamingWriteSupport } -private[noop] object NoopTable extends Table with SupportsBatchWrite { +private[noop] object NoopTable extends Table with SupportsBatchWrite with SupportsStreamingWrite { override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = NoopWriteBuilder override def name(): String = "noop-table" override def schema(): StructType = new StructType() } -private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsSaveMode { - override def buildForBatch(): BatchWrite = NoopBatchWrite +private[noop] object NoopWriteBuilder extends WriteBuilder + with SupportsSaveMode with SupportsOutputMode { override def mode(mode: SaveMode): WriteBuilder = this + override def outputMode(mode: OutputMode): WriteBuilder = this + override def buildForBatch(): BatchWrite = NoopBatchWrite + override def buildForStreaming(): StreamingWrite = NoopStreamingWrite } private[noop] object NoopBatchWrite extends BatchWrite { @@ -72,7 +65,7 @@ private[noop] object NoopWriter extends DataWriter[InternalRow] { override def abort(): Unit = {} } -private[noop] object NoopStreamingWriteSupport extends StreamingWriteSupport { +private[noop] object NoopStreamingWrite extends StreamingWrite { override def createStreamingWriterFactory(): StreamingDataWriterFactory = NoopStreamingDataWriterFactory override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} @@ -85,4 +78,3 @@ private[noop] object NoopStreamingDataWriterFactory extends StreamingDataWriterF taskId: Long, epochId: Long): DataWriter[InternalRow] = NoopWriter } - diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala deleted file mode 100644 index f11703c8a2773..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2StringFormat.scala +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.execution.datasources.v2 - -import org.apache.commons.lang3.StringUtils - -import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression} -import org.apache.spark.sql.catalyst.util.truncatedString -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.DataSourceV2 -import org.apache.spark.util.Utils - -/** - * A trait that can be used by data source v2 related query plans(both logical and physical), to - * provide a string format of the data source information for explain. - */ -trait DataSourceV2StringFormat { - - /** - * The instance of this data source implementation. Note that we only consider its class in - * equals/hashCode, not the instance itself. - */ - def source: DataSourceV2 - - /** - * The output of the data source reader, w.r.t. column pruning. - */ - def output: Seq[Attribute] - - /** - * The options for this data source reader. - */ - def options: Map[String, String] - - /** - * The filters which have been pushed to the data source. - */ - def pushedFilters: Seq[Expression] - - private def sourceName: String = source match { - case registered: DataSourceRegister => registered.shortName() - // source.getClass.getSimpleName can cause Malformed class name error, - // call safer `Utils.getSimpleName` instead - case _ => Utils.getSimpleName(source.getClass) - } - - def metadataString(maxFields: Int): String = { - val entries = scala.collection.mutable.ArrayBuffer.empty[(String, String)] - - if (pushedFilters.nonEmpty) { - entries += "Filters" -> pushedFilters.mkString("[", ", ", "]") - } - - // TODO: we should only display some standard options like path, table, etc. - if (options.nonEmpty) { - entries += "Options" -> Utils.redact(options).map { - case (k, v) => s"$k=$v" - }.mkString("[", ",", "]") - } - - val outputStr = truncatedString(output, "[", ", ", "]", maxFields) - - val entriesStr = if (entries.nonEmpty) { - truncatedString(entries.map { - case (key, value) => key + ": " + StringUtils.abbreviate(value, 100) - }, " (", ", ", ")", maxFields) - } else { - "" - } - - s"$sourceName$outputStr$entriesStr" - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala index e9cc3991155c4..30897d86f8179 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Utils.scala @@ -21,8 +21,7 @@ import java.util.regex.Pattern import org.apache.spark.internal.Logging import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceV2, SessionConfigSupport} +import org.apache.spark.sql.sources.v2.{SessionConfigSupport, TableProvider} private[sql] object DataSourceV2Utils extends Logging { @@ -34,34 +33,28 @@ private[sql] object DataSourceV2Utils extends Logging { * `spark.datasource.$keyPrefix`. A session config `spark.datasource.$keyPrefix.xxx -> yyy` will * be transformed into `xxx -> yyy`. * - * @param ds a [[DataSourceV2]] object + * @param source a [[TableProvider]] object * @param conf the session conf * @return an immutable map that contains all the extracted and transformed k/v pairs. 
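A sketch of the consumer side of this utility: a hypothetical `ExampleSource` declaring a key prefix, and the session config that would be picked up for it. The class name and prefix are assumptions, not part of this patch.

```scala
import org.apache.spark.sql.sources.v2.{DataSourceOptions, SessionConfigSupport, Table, TableProvider}

// Hypothetical provider; only keyPrefix() matters for config propagation.
class ExampleSource extends TableProvider with SessionConfigSupport {
  override def keyPrefix(): String = "example"
  override def getTable(options: DataSourceOptions): Table = ???  // out of scope here
}

// With spark.conf.set("spark.datasource.example.region", "us-east-1") in the session,
// extractSessionConfigs(new ExampleSource, conf) yields Map("region" -> "us-east-1").
```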
*/ - def extractSessionConfigs(ds: DataSourceV2, conf: SQLConf): Map[String, String] = ds match { - case cs: SessionConfigSupport => - val keyPrefix = cs.keyPrefix() - require(keyPrefix != null, "The data source config key prefix can't be null.") - - val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.+)") - - conf.getAllConfs.flatMap { case (key, value) => - val m = pattern.matcher(key) - if (m.matches() && m.groupCount() > 0) { - Seq((m.group(1), value)) - } else { - Seq.empty + def extractSessionConfigs(source: TableProvider, conf: SQLConf): Map[String, String] = { + source match { + case cs: SessionConfigSupport => + val keyPrefix = cs.keyPrefix() + require(keyPrefix != null, "The data source config key prefix can't be null.") + + val pattern = Pattern.compile(s"^spark\\.datasource\\.$keyPrefix\\.(.+)") + + conf.getAllConfs.flatMap { case (key, value) => + val m = pattern.matcher(key) + if (m.matches() && m.groupCount() > 0) { + Seq((m.group(1), value)) + } else { + Seq.empty + } } - } - - case _ => Map.empty - } - def failForUserSpecifiedSchema[T](ds: DataSourceV2): T = { - val name = ds match { - case register: DataSourceRegister => register.shortName() - case _ => ds.getClass.getName + case _ => Map.empty } - throw new UnsupportedOperationException(name + " source does not support user-specified schema") } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 2c339759f95ba..cca279030dfa7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWrite, RateCo import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.writer.streaming.SupportsOutputMode import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.Clock @@ -513,13 +514,16 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan - case s: StreamingWriteSupportProvider => - val writer = s.createStreamingWriteSupport( - s"$runId", - newAttributePlan.schema, - outputMode, - new DataSourceOptions(extraOptions.asJava)) - WriteToDataSourceV2(new MicroBatchWrite(currentBatchId, writer), newAttributePlan) + case s: SupportsStreamingWrite => + // TODO: we should translate OutputMode to concrete write actions like truncate, but + // the truncate action is being developed in SPARK-26666. 
+ val writeBuilder = s.newWriteBuilder(new DataSourceOptions(extraOptions.asJava)) + .withQueryId(runId.toString) + .withInputDataSchema(newAttributePlan.schema) + val streamingWrite = writeBuilder.asInstanceOf[SupportsOutputMode] + .outputMode(outputMode) + .buildForStreaming() + WriteToDataSourceV2(new MicroBatchWrite(currentBatchId, streamingWrite), newAttributePlan) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } @@ -549,7 +553,7 @@ class MicroBatchExecution( SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) - case _: StreamingWriteSupportProvider => + case _: SupportsStreamingWrite => // This doesn't accumulate any data - it just forces execution of the microbatch writer. nextBatch.collect() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 83d38dcade7e6..1b7aa548e6d21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.sources.v2.{DataSourceV2, Table} +import org.apache.spark.sql.sources.v2.{Table, TableProvider} object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -86,13 +86,13 @@ case class StreamingExecutionRelation( // know at read time whether the query is continuous or not, so we need to be able to // swap a V1 relation back in. /** - * Used to link a [[DataSourceV2]] into a streaming + * Used to link a [[TableProvider]] into a streaming * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. This is only used for creating * a streaming [[org.apache.spark.sql.DataFrame]] from [[org.apache.spark.sql.DataFrameReader]], * and should be converted before passing to [[StreamExecution]]. 
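Both the micro-batch and the continuous paths now follow the same builder chain; below is a condensed sketch of that chain, with `buildStreamingWrite` as a made-up helper name.

```scala
import java.util.UUID
import scala.collection.JavaConverters._

import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite}
import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingWrite, SupportsOutputMode}
import org.apache.spark.sql.streaming.OutputMode
import org.apache.spark.sql.types.StructType

def buildStreamingWrite(
    table: SupportsStreamingWrite,
    schema: StructType,
    mode: OutputMode,
    options: Map[String, String]): StreamingWrite = {
  val builder = table.newWriteBuilder(new DataSourceOptions(options.asJava))
    .withQueryId(UUID.randomUUID().toString)
    .withInputDataSchema(schema)
  builder match {
    // Until SPARK-26666 lands, the output mode is forwarded through SupportsOutputMode.
    case b: SupportsOutputMode => b.outputMode(mode).buildForStreaming()
    case b => b.buildForStreaming()
  }
}
```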
*/ case class StreamingRelationV2( - dataSource: DataSourceV2, + source: TableProvider, sourceName: String, table: Table, extraOptions: Map[String, String], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 9c5c16f4f5d13..348bc767b2c46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -18,10 +18,11 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql._ -import org.apache.spark.sql.execution.streaming.sources.ConsoleWriteSupport +import org.apache.spark.sql.execution.streaming.sources.ConsoleWrite import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.writer.WriteBuilder +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingWrite, SupportsOutputMode} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -30,17 +31,12 @@ case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) override def schema: StructType = data.schema } -class ConsoleSinkProvider extends DataSourceV2 - with StreamingWriteSupportProvider +class ConsoleSinkProvider extends TableProvider with DataSourceRegister with CreatableRelationProvider { - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new ConsoleWriteSupport(schema, options) + override def getTable(options: DataSourceOptions): Table = { + ConsoleTable } def createRelation( @@ -60,3 +56,28 @@ class ConsoleSinkProvider extends DataSourceV2 def shortName(): String = "console" } + +object ConsoleTable extends Table with SupportsStreamingWrite { + + override def name(): String = "console" + + override def schema(): StructType = StructType(Nil) + + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + new WriteBuilder with SupportsOutputMode { + private var inputSchema: StructType = _ + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.inputSchema = schema + this + } + + override def outputMode(mode: OutputMode): WriteBuilder = this + + override def buildForStreaming(): StreamingWrite = { + assert(inputSchema != null) + new ConsoleWrite(inputSchema, options) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index b22795d207760..20101c7fda320 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -32,8 +32,9 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming.{StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 -import 
org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider, SupportsContinuousRead} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsContinuousRead, SupportsStreamingWrite} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} +import org.apache.spark.sql.sources.v2.writer.streaming.SupportsOutputMode import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.Clock @@ -42,7 +43,7 @@ class ContinuousExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: StreamingWriteSupportProvider, + sink: SupportsStreamingWrite, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ -174,12 +175,15 @@ class ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - val writer = sink.createStreamingWriteSupport( - s"$runId", - withNewSources.schema, - outputMode, - new DataSourceOptions(extraOptions.asJava)) - val planWithSink = WriteToContinuousDataSource(writer, withNewSources) + // TODO: we should translate OutputMode to concrete write actions like truncate, but + // the truncate action is being developed in SPARK-26666. + val writeBuilder = sink.newWriteBuilder(new DataSourceOptions(extraOptions.asJava)) + .withQueryId(runId.toString) + .withInputDataSchema(withNewSources.schema) + val streamingWrite = writeBuilder.asInstanceOf[SupportsOutputMode] + .outputMode(outputMode) + .buildForStreaming() + val planWithSink = WriteToContinuousDataSource(streamingWrite, withNewSources) reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( @@ -214,9 +218,8 @@ class ContinuousExecution( trigger.asInstanceOf[ContinuousTrigger].intervalMs.toString) // Use the parent Spark session for the endpoint since it's where this query ID is registered. - val epochEndpoint = - EpochCoordinatorRef.create( - writer, stream, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) + val epochEndpoint = EpochCoordinatorRef.create( + streamingWrite, stream, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index d1bda79f4b6ef..a99842220424d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -25,7 +25,7 @@ import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeR import org.apache.spark.sql.SparkSession import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.util.RpcUtils private[continuous] sealed trait EpochCoordinatorMessage extends Serializable @@ -82,7 +82,7 @@ private[sql] object EpochCoordinatorRef extends Logging { * Create a reference to a new [[EpochCoordinator]]. 
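For reference, the continuous path above is driven from the public API roughly as follows; the source, sink and checkpoint path are placeholders, and both the rate source and the console sink support continuous mode.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.Trigger

val spark = SparkSession.builder().appName("continuous-example").getOrCreate()

// Epoch boundaries are coordinated by the EpochCoordinator described here.
val query = spark.readStream
  .format("rate")
  .option("rowsPerSecond", "10")
  .load()
  .writeStream
  .format("console")
  .trigger(Trigger.Continuous("1 second"))
  .option("checkpointLocation", "/tmp/chk/continuous")  // placeholder
  .start()
```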
*/ def create( - writeSupport: StreamingWriteSupport, + writeSupport: StreamingWrite, stream: ContinuousStream, query: ContinuousExecution, epochCoordinatorId: String, @@ -115,7 +115,7 @@ private[sql] object EpochCoordinatorRef extends Logging { * have both committed and reported an end offset for a given epoch. */ private[continuous] class EpochCoordinator( - writeSupport: StreamingWriteSupport, + writeSupport: StreamingWrite, stream: ContinuousStream, query: ContinuousExecution, startEpoch: Long, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala index 7ad21cc304e7c..54f484c4adae3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSource.scala @@ -19,13 +19,13 @@ package org.apache.spark.sql.execution.streaming.continuous import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite /** * The logical plan for writing data in a continuous stream. */ -case class WriteToContinuousDataSource( - writeSupport: StreamingWriteSupport, query: LogicalPlan) extends LogicalPlan { +case class WriteToContinuousDataSource(write: StreamingWrite, query: LogicalPlan) + extends LogicalPlan { override def children: Seq[LogicalPlan] = Seq(query) override def output: Seq[Attribute] = Nil } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala index 2178466d63142..2f3af6a6544c4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/WriteToContinuousDataSourceExec.scala @@ -26,21 +26,22 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.streaming.StreamExecution -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite /** - * The physical plan for writing data into a continuous processing [[StreamingWriteSupport]]. + * The physical plan for writing data into a continuous processing [[StreamingWrite]]. */ -case class WriteToContinuousDataSourceExec(writeSupport: StreamingWriteSupport, query: SparkPlan) - extends UnaryExecNode with Logging { +case class WriteToContinuousDataSourceExec(write: StreamingWrite, query: SparkPlan) + extends UnaryExecNode with Logging { + override def child: SparkPlan = query override def output: Seq[Attribute] = Nil override protected def doExecute(): RDD[InternalRow] = { - val writerFactory = writeSupport.createStreamingWriterFactory() + val writerFactory = write.createStreamingWriterFactory() val rdd = new ContinuousWriteRDD(query.execute(), writerFactory) - logInfo(s"Start processing data source write support: $writeSupport. 
" + + logInfo(s"Start processing data source write support: $write. " + s"The input RDD has ${rdd.partitions.length} partitions.") EpochCoordinatorRef.get( sparkContext.getLocalProperty(ContinuousExecution.EPOCH_COORDINATOR_ID_KEY), diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala similarity index 94% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala index 833e62f35ede1..f2ff30bcf1bef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWriteSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala @@ -22,12 +22,12 @@ import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType /** Common methods used to create writes for the the console sink */ -class ConsoleWriteSupport(schema: StructType, options: DataSourceOptions) - extends StreamingWriteSupport with Logging { +class ConsoleWrite(schema: StructType, options: DataSourceOptions) + extends StreamingWrite with Logging { // Number of rows to display, by default 20 rows protected val numRowsToShow = options.getInt("numRows", 20) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala similarity index 66% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala index 4218fd51ad206..6fbb59c43625a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriteSupportProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala @@ -22,63 +22,73 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.python.PythonForeachWriter -import org.apache.spark.sql.sources.v2.{DataSourceOptions, StreamingWriteSupportProvider} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite, Table} +import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite, SupportsOutputMode} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** - * A 
[[org.apache.spark.sql.sources.v2.DataSourceV2]] for forwarding data into the specified - * [[ForeachWriter]]. + * A write-only table for forwarding data into the specified [[ForeachWriter]]. * * @param writer The [[ForeachWriter]] to process all data. * @param converter An object to convert internal rows to target type T. Either it can be * a [[ExpressionEncoder]] or a direct converter function. * @tparam T The expected type of the sink. */ -case class ForeachWriteSupportProvider[T]( +case class ForeachWriterTable[T]( writer: ForeachWriter[T], converter: Either[ExpressionEncoder[T], InternalRow => T]) - extends StreamingWriteSupportProvider { - - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new StreamingWriteSupport { - override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} - - override def createStreamingWriterFactory(): StreamingDataWriterFactory = { - val rowConverter: InternalRow => T = converter match { - case Left(enc) => - val boundEnc = enc.resolveAndBind( - schema.toAttributes, - SparkSession.getActiveSession.get.sessionState.analyzer) - boundEnc.fromRow - case Right(func) => - func - } - ForeachWriterFactory(writer, rowConverter) + extends Table with SupportsStreamingWrite { + + override def name(): String = "ForeachSink" + + override def schema(): StructType = StructType(Nil) + + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + new WriteBuilder with SupportsOutputMode { + private var inputSchema: StructType = _ + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.inputSchema = schema + this } - override def toString: String = "ForeachSink" + override def outputMode(mode: OutputMode): WriteBuilder = this + + override def buildForStreaming(): StreamingWrite = { + new StreamingWrite { + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = {} + + override def createStreamingWriterFactory(): StreamingDataWriterFactory = { + val rowConverter: InternalRow => T = converter match { + case Left(enc) => + val boundEnc = enc.resolveAndBind( + inputSchema.toAttributes, + SparkSession.getActiveSession.get.sessionState.analyzer) + boundEnc.fromRow + case Right(func) => + func + } + ForeachWriterFactory(writer, rowConverter) + } + } + } } } } -object ForeachWriteSupportProvider { +object ForeachWriterTable { def apply[T]( writer: ForeachWriter[T], - encoder: ExpressionEncoder[T]): ForeachWriteSupportProvider[_] = { + encoder: ExpressionEncoder[T]): ForeachWriterTable[_] = { writer match { case pythonWriter: PythonForeachWriter => - new ForeachWriteSupportProvider[UnsafeRow]( + new ForeachWriterTable[UnsafeRow]( pythonWriter, Right((x: InternalRow) => x.asInstanceOf[UnsafeRow])) case _ => - new ForeachWriteSupportProvider[T](writer, Left(encoder)) + new ForeachWriterTable[T](writer, Left(encoder)) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala index 143235efee81d..f3951897ea747 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/MicroBatchWrite.scala @@ -19,14 +19,14 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriter, DataWriterFactory, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} /** * A [[BatchWrite]] used to hook V2 stream writers into a microbatch plan. It implements * the non-streaming interface, forwarding the epoch ID determined at construction to a wrapped * streaming write support. */ -class MicroBatchWrite(eppchId: Long, val writeSupport: StreamingWriteSupport) extends BatchWrite { +class MicroBatchWrite(eppchId: Long, val writeSupport: StreamingWrite) extends BatchWrite { override def commit(messages: Array[WriterCommitMessage]): Unit = { writeSupport.commit(eppchId, messages) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 075c6b9362ba2..3a0082536512d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -40,8 +40,7 @@ import org.apache.spark.sql.types._ * generated rows. The source will try its best to reach `rowsPerSecond`, but the query may * be resource constrained, and `numPartitions` can be tweaked to help reach the desired speed. */ -class RateStreamProvider extends DataSourceV2 - with TableProvider with DataSourceRegister { +class RateStreamProvider extends TableProvider with DataSourceRegister { import RateStreamProvider._ override def getTable(options: DataSourceOptions): Table = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala index c3b24a8f65dd9..8ac5bfc307aa3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -31,8 +31,7 @@ import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} -class TextSocketSourceProvider extends DataSourceV2 - with TableProvider with DataSourceRegister with Logging { +class TextSocketSourceProvider extends TableProvider with DataSourceRegister with Logging { private def checkParameters(params: DataSourceOptions): Unit = { logWarning("The socket source should not be used for production applications! 
" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index c50dc7bcb8da1..3fc2cbe0fde57 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -32,9 +32,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, StreamingWriteSupportProvider} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWriteSupport} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite, SupportsOutputMode} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -42,15 +42,31 @@ import org.apache.spark.sql.types.StructType * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider - with MemorySinkBase with Logging { - - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - new MemoryStreamingWriteSupport(this, mode, schema) +class MemorySinkV2 extends SupportsStreamingWrite with MemorySinkBase with Logging { + + override def name(): String = "MemorySinkV2" + + override def schema(): StructType = StructType(Nil) + + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + new WriteBuilder with SupportsOutputMode { + private var mode: OutputMode = _ + private var inputSchema: StructType = _ + + override def outputMode(mode: OutputMode): WriteBuilder = { + this.mode = mode + this + } + + override def withInputDataSchema(schema: StructType): WriteBuilder = { + this.inputSchema = schema + this + } + + override def buildForStreaming(): StreamingWrite = { + new MemoryStreamingWrite(MemorySinkV2.this, mode, inputSchema) + } + } } private case class AddedData(batchId: Long, data: Array[Row]) @@ -122,9 +138,9 @@ class MemorySinkV2 extends DataSourceV2 with StreamingWriteSupportProvider case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} -class MemoryStreamingWriteSupport( +class MemoryStreamingWrite( val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) - extends StreamingWriteSupport { + extends StreamingWrite { override def createStreamingWriterFactory: MemoryWriterFactory = { MemoryWriterFactory(outputMode, schema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index a10bd2218eb38..81bffc32027a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -173,7 +173,7 @@ 
final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo ds match { case provider: TableProvider => val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - ds = provider, conf = sparkSession.sessionState.conf) + source = provider, conf = sparkSession.sessionState.conf) val options = sessionOptions ++ extraOptions val dsOptions = new DataSourceOptions(options.asJava) val table = userSpecifiedSchema match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index ea596ba728c19..984199488fa7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite, TableProvider} /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. file systems, @@ -278,7 +278,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { query } else if (source == "foreach") { assertNotPartitioned("foreach") - val sink = ForeachWriteSupportProvider[T](foreachWriter, ds.exprEnc) + val sink = ForeachWriterTable[T](foreachWriter, ds.exprEnc) df.sparkSession.sessionState.streamingQueryManager.startQuery( extraOptions.get("queryName"), extraOptions.get("checkpointLocation"), @@ -304,30 +304,29 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { useTempCheckpointLocation = true, trigger = trigger) } else { - val ds = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) + val cls = DataSource.lookupDataSource(source, df.sparkSession.sessionState.conf) val disabledSources = df.sparkSession.sqlContext.conf.disabledV2StreamingWriters.split(",") - var options = extraOptions.toMap - val sink = ds.getConstructor().newInstance() match { - case w: StreamingWriteSupportProvider - if !disabledSources.contains(w.getClass.getCanonicalName) => - val sessionOptions = DataSourceV2Utils.extractSessionConfigs( - w, df.sparkSession.sessionState.conf) - options = sessionOptions ++ extraOptions - w - case _ => - val ds = DataSource( - df.sparkSession, - className = source, - options = options, - partitionColumns = normalizedParCols.getOrElse(Nil)) - ds.createSink(outputMode) + val useV1Source = disabledSources.contains(cls.getCanonicalName) + + val sink = if (classOf[TableProvider].isAssignableFrom(cls) && !useV1Source) { + val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] + val sessionOptions = DataSourceV2Utils.extractSessionConfigs( + source = provider, conf = df.sparkSession.sessionState.conf) + val options = sessionOptions ++ extraOptions + val dsOptions = new DataSourceOptions(options.asJava) + provider.getTable(dsOptions) match { + case s: SupportsStreamingWrite => s + case _ => createV1Sink() + } + } else { + createV1Sink() } df.sparkSession.sessionState.streamingQueryManager.startQuery( - options.get("queryName"), - options.get("checkpointLocation"), + extraOptions.get("queryName"), + extraOptions.get("checkpointLocation"), df, - options, + extraOptions.toMap, 
sink, outputMode, useTempCheckpointLocation = source == "console" || source == "noop", @@ -336,6 +335,15 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } } + private def createV1Sink(): BaseStreamingSink = { + val ds = DataSource( + df.sparkSession, + className = source, + options = extraOptions.toMap, + partitionColumns = normalizedParCols.getOrElse(Nil)) + ds.createSink(outputMode) + } + /** * Sets the output of the streaming query to be processed using the provided writer object. * object. See [[org.apache.spark.sql.ForeachWriter]] for more details on the lifecycle and diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 881cd96cc9dc9..e6773c5cc3bd4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS -import org.apache.spark.sql.sources.v2.StreamingWriteSupportProvider +import org.apache.spark.sql.sources.v2.SupportsStreamingWrite import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -254,7 +254,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } (sink, trigger) match { - case (v2Sink: StreamingWriteSupportProvider, trigger: ContinuousTrigger) => + case (v2Sink: SupportsStreamingWrite, trigger: ContinuousTrigger) => if (operationCheckEnabled) { UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) } diff --git a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister index a36b0cfa6ff18..914af589384df 100644 --- a/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister +++ b/sql/core/src/test/resources/META-INF/services/org.apache.spark.sql.sources.DataSourceRegister @@ -9,6 +9,6 @@ org.apache.spark.sql.streaming.sources.FakeReadMicroBatchOnly org.apache.spark.sql.streaming.sources.FakeReadContinuousOnly org.apache.spark.sql.streaming.sources.FakeReadBothModes org.apache.spark.sql.streaming.sources.FakeReadNeitherMode -org.apache.spark.sql.streaming.sources.FakeWriteSupportProvider +org.apache.spark.sql.streaming.sources.FakeWriteOnly org.apache.spark.sql.streaming.sources.FakeNoWrite org.apache.spark.sql.streaming.sources.FakeWriteSupportProviderV1Fallback diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index 61857365ac989..e804377540517 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -43,9 +43,9 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("streaming writer") { val sink = new MemorySinkV2 - val writeSupport = new MemoryStreamingWriteSupport( + val write = new MemoryStreamingWrite( sink, OutputMode.Append(), new StructType().add("i", "int")) - 
writeSupport.commit(0, + write.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), @@ -53,7 +53,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { )) assert(sink.latestBatchId.contains(0)) assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - writeSupport.commit(19, + write.commit(19, Array( MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), MemoryWriterCommitMessage(0, Seq(Row(33))) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala index f903c17923d0f..0b1e3b5fb076d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2UtilsSuite.scala @@ -33,8 +33,8 @@ class DataSourceV2UtilsSuite extends SparkFunSuite { conf.setConfString(s"spark.sql.$keyPrefix.config.name", "false") conf.setConfString("spark.datasource.another.config.name", "123") conf.setConfString(s"spark.datasource.$keyPrefix.", "123") - val cs = classOf[DataSourceV2WithSessionConfig].getConstructor().newInstance() - val confs = DataSourceV2Utils.extractSessionConfigs(cs.asInstanceOf[DataSourceV2], conf) + val source = new DataSourceV2WithSessionConfig + val confs = DataSourceV2Utils.extractSessionConfigs(source, conf) assert(confs.size == 2) assert(confs.keySet.filter(_.startsWith("spark.datasource")).size == 0) assert(confs.keySet.filter(_.startsWith("not.exist.prefix")).size == 0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index daca65fd1ad2c..c56a54598cd4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -38,8 +38,7 @@ import org.apache.spark.util.SerializableConfiguration * Each task writes data to `target/_temporary/uniqueId/$jobId-$partitionId-$attemptNumber`. * Each job moves files from `target/_temporary/uniqueId/` to `target`. 
*/ -class SimpleWritableDataSource extends DataSourceV2 - with TableProvider with SessionConfigSupport { +class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { private val tableSchema = new StructType().add("i", "long").add("j", "long") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala index d3d210c02e90d..bad22590807a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueuedDataReaderSuite.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection, UnsafeRow} import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousPartitionReader, ContinuousStream, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.types.{DataType, IntegerType, StructType} @@ -43,7 +43,7 @@ class ContinuousQueuedDataReaderSuite extends StreamTest with MockitoSugar { override def beforeEach(): Unit = { super.beforeEach() epochEndpoint = EpochCoordinatorRef.create( - mock[StreamingWriteSupport], + mock[StreamingWrite], mock[ContinuousStream], mock[ContinuousExecution], coordinatorId, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index a0b56ec17f0be..f74285f4b0fb3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.LocalSparkSession import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.test.TestSparkSession class EpochCoordinatorSuite @@ -40,13 +40,13 @@ class EpochCoordinatorSuite private var epochCoordinator: RpcEndpointRef = _ - private var writeSupport: StreamingWriteSupport = _ + private var writeSupport: StreamingWrite = _ private var query: ContinuousExecution = _ private var orderVerifier: InOrder = _ override def beforeEach(): Unit = { val stream = mock[ContinuousStream] - writeSupport = mock[StreamingWriteSupport] + writeSupport = mock[StreamingWrite] query = mock[ContinuousExecution] orderVerifier = inOrder(writeSupport, query) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 62f166602941c..c841793fdd4a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming._ -import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWriteSupport +import org.apache.spark.sql.sources.v2.writer.WriteBuilder import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -71,13 +71,10 @@ trait FakeContinuousReadTable extends Table with SupportsContinuousRead { override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new FakeScanBuilder } -trait FakeStreamingWriteSupportProvider extends StreamingWriteSupportProvider { - override def createStreamingWriteSupport( - queryId: String, - schema: StructType, - mode: OutputMode, - options: DataSourceOptions): StreamingWriteSupport = { - LastWriteOptions.options = options +trait FakeStreamingWriteTable extends Table with SupportsStreamingWrite { + override def name(): String = "fake" + override def schema(): StructType = StructType(Seq()) + override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { throw new IllegalStateException("fake sink - cannot actually write") } } @@ -129,20 +126,33 @@ class FakeReadNeitherMode extends DataSourceRegister with TableProvider { } } -class FakeWriteSupportProvider +class FakeWriteOnly extends DataSourceRegister - with FakeStreamingWriteSupportProvider + with TableProvider with SessionConfigSupport { override def shortName(): String = "fake-write-microbatch-continuous" override def keyPrefix: String = shortName() + + override def getTable(options: DataSourceOptions): Table = { + LastWriteOptions.options = options + new Table with FakeStreamingWriteTable { + override def name(): String = "fake" + override def schema(): StructType = StructType(Nil) + } + } } -class FakeNoWrite extends DataSourceRegister { +class FakeNoWrite extends DataSourceRegister with TableProvider { override def shortName(): String = "fake-write-neither-mode" + override def getTable(options: DataSourceOptions): Table = { + new Table { + override def name(): String = "fake" + override def schema(): StructType = StructType(Nil) + } + } } - case class FakeWriteV1FallbackException() extends Exception class FakeSink extends Sink { @@ -150,17 +160,24 @@ class FakeSink extends Sink { } class FakeWriteSupportProviderV1Fallback extends DataSourceRegister - with FakeStreamingWriteSupportProvider with StreamSinkProvider { + with TableProvider with StreamSinkProvider { override def createSink( - sqlContext: SQLContext, - parameters: Map[String, String], - partitionColumns: Seq[String], - outputMode: OutputMode): Sink = { + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { new FakeSink() } override def shortName(): String = "fake-write-v1-fallback" + + override def getTable(options: DataSourceOptions): Table = { + new Table with FakeStreamingWriteTable { + override def name(): String = "fake" + override def schema(): StructType = StructType(Nil) + } + } } object LastReadOptions { @@ -260,7 +277,7 @@ class StreamingDataSourceV2Suite extends StreamTest { testPositiveCaseWithQuery( "fake-read-microbatch-continuous", "fake-write-v1-fallback", Trigger.Once()) { v2Query => 
assert(v2Query.asInstanceOf[StreamingQueryWrapper].streamingQuery.sink - .isInstanceOf[FakeWriteSupportProviderV1Fallback]) + .isInstanceOf[Table]) } // Ensure we create a V1 sink with the config. Note the config is a comma separated @@ -319,19 +336,20 @@ class StreamingDataSourceV2Suite extends StreamTest { for ((read, write, trigger) <- cases) { testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { - val table = DataSource.lookupDataSource(read, spark.sqlContext.conf).getConstructor() + val sourceTable = DataSource.lookupDataSource(read, spark.sqlContext.conf).getConstructor() + .newInstance().asInstanceOf[TableProvider].getTable(DataSourceOptions.empty()) + + val sinkTable = DataSource.lookupDataSource(write, spark.sqlContext.conf).getConstructor() .newInstance().asInstanceOf[TableProvider].getTable(DataSourceOptions.empty()) - val writeSource = DataSource.lookupDataSource(write, spark.sqlContext.conf). - getConstructor().newInstance() - (table, writeSource, trigger) match { + (sourceTable, sinkTable, trigger) match { // Valid microbatch queries. - case (_: SupportsMicroBatchRead, _: StreamingWriteSupportProvider, t) + case (_: SupportsMicroBatchRead, _: SupportsStreamingWrite, t) if !t.isInstanceOf[ContinuousTrigger] => testPositiveCase(read, write, trigger) // Valid continuous queries. - case (_: SupportsContinuousRead, _: StreamingWriteSupportProvider, + case (_: SupportsContinuousRead, _: SupportsStreamingWrite, _: ContinuousTrigger) => testPositiveCase(read, write, trigger) @@ -342,12 +360,12 @@ class StreamingDataSourceV2Suite extends StreamTest { s"Data source $read does not support streamed reading") // Invalid - can't write - case (_, w, _) if !w.isInstanceOf[StreamingWriteSupportProvider] => + case (_, w, _) if !w.isInstanceOf[SupportsStreamingWrite] => testNegativeCase(read, write, trigger, s"Data source $write does not support streamed writing") // Invalid - trigger is continuous but reader is not - case (r, _: StreamingWriteSupportProvider, _: ContinuousTrigger) + case (r, _: SupportsStreamingWrite, _: ContinuousTrigger) if !r.isInstanceOf[SupportsContinuousRead] => testNegativeCase(read, write, trigger, s"Data source $read does not support continuous processing") From 167ffec6b0c943aabea0a9a868bab214bdb2cdfd Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 8 Mar 2019 19:31:49 +0800 Subject: [PATCH 04/70] [SPARK-24252][SQL] Add v2 catalog plugin system ## What changes were proposed in this pull request? This adds a v2 API for adding new catalog plugins to Spark. * Catalog implementations extend `CatalogPlugin` and are loaded via reflection, similar to data sources * `Catalogs` loads and initializes catalogs using configuration from a `SQLConf` * `CaseInsensitiveStringMap` is used to pass configuration to `CatalogPlugin` via `initialize` Catalogs are configured by adding config properties starting with `spark.sql.catalog.(name)`. The name property must specify a class that implements `CatalogPlugin`. Other properties under the namespace (`spark.sql.catalog.(name).(prop)`) are passed to the provider during initialization along with the catalog name. This replaces #21306, which will be implemented in two multiple parts: the catalog plugin system (this commit) and specific catalog APIs, like `TableCatalog`. ## How was this patch tested? Added test suites for `CaseInsensitiveStringMap` and for catalog loading. Closes #23915 from rdblue/SPARK-24252-add-v2-catalog-plugins. 
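For illustration only (not part of this patch): a minimal Scala sketch of what a catalog plugin and its configuration could look like under the loading rules described above. `ExampleCatalog` and the `warehouse` option are hypothetical names.

```scala
import org.apache.spark.sql.catalog.v2.CatalogPlugin
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Hypothetical plugin; it must expose a public no-arg constructor so that
// Catalogs.load can instantiate it via reflection.
class ExampleCatalog extends CatalogPlugin {
  private var catalogName: String = _
  private var warehouse: String = _  // illustrative option, may be null if unset

  // Called once, right after instantiation. `options` holds every
  // `spark.sql.catalog.<name>.<key>` property with the prefix removed.
  override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = {
    catalogName = name
    warehouse = options.get("warehouse")
  }

  override def name(): String = catalogName
}

// Registration is purely configuration-driven, for example:
//   spark.sql.catalog.example            com.example.ExampleCatalog
//   spark.sql.catalog.example.warehouse  /tmp/example-warehouse   // arrives as options.get("warehouse")
```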
Authored-by: Ryan Blue Signed-off-by: Wenchen Fan --- .../spark/sql/catalog/v2/CatalogPlugin.java | 61 +++++ .../apache/spark/sql/catalog/v2/Catalogs.java | 109 +++++++++ .../sql/util/CaseInsensitiveStringMap.java | 110 +++++++++ .../sql/catalog/v2/CatalogLoadingSuite.java | 208 ++++++++++++++++++ .../util/CaseInsensitiveStringMapSuite.java | 48 ++++ .../org/apache/spark/sql/SparkSession.scala | 8 + 6 files changed, 544 insertions(+) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/CatalogPlugin.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java create mode 100644 sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java create mode 100644 sql/catalyst/src/test/java/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/CatalogPlugin.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/CatalogPlugin.java new file mode 100644 index 0000000000000..5d4995a05d233 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/CatalogPlugin.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; + +/** + * A marker interface to provide a catalog implementation for Spark. + *
<p>
+ * Implementations can provide catalog functions by implementing additional interfaces for tables, + * views, and functions. + *
<p>
+ * Catalog implementations must implement this marker interface to be loaded by + * {@link Catalogs#load(String, SQLConf)}. The loader will instantiate catalog classes using the + * required public no-arg constructor. After creating an instance, it will be configured by calling + * {@link #initialize(String, CaseInsensitiveStringMap)}. + *
<p>
+ * Catalog implementations are registered to a name by adding a configuration option to Spark: + * {@code spark.sql.catalog.catalog-name=com.example.YourCatalogClass}. All configuration properties + * in the Spark configuration that share the catalog name prefix, + * {@code spark.sql.catalog.catalog-name.(key)=(value)} will be passed in the case insensitive + * string map of options in initialization with the prefix removed. + * {@code name}, is also passed and is the catalog's name; in this case, "catalog-name". + */ +@Experimental +public interface CatalogPlugin { + /** + * Called to initialize configuration. + *
<p>
+ * This method is called once, just after the provider is instantiated. + * + * @param name the name used to identify and load this catalog + * @param options a case-insensitive string map of configuration + */ + void initialize(String name, CaseInsensitiveStringMap options); + + /** + * Called to get this catalog's name. + *
<p>
+ * This method is only called after {@link #initialize(String, CaseInsensitiveStringMap)} is + * called to pass the catalog's name. + */ + String name(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java new file mode 100644 index 0000000000000..efae26636a4bc --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.SparkException; +import org.apache.spark.annotation.Private; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.apache.spark.util.Utils; + +import java.util.Map; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static scala.collection.JavaConverters.mapAsJavaMapConverter; + +@Private +public class Catalogs { + private Catalogs() { + } + + /** + * Load and configure a catalog by name. + *
<p>
+ * This loads, instantiates, and initializes the catalog plugin for each call; it does not cache + * or reuse instances. + * + * @param name a String catalog name + * @param conf a SQLConf + * @return an initialized CatalogPlugin + * @throws SparkException If the plugin class cannot be found or instantiated + */ + public static CatalogPlugin load(String name, SQLConf conf) throws SparkException { + String pluginClassName = conf.getConfString("spark.sql.catalog." + name, null); + if (pluginClassName == null) { + throw new SparkException(String.format( + "Catalog '%s' plugin class not found: spark.sql.catalog.%s is not defined", name, name)); + } + + ClassLoader loader = Utils.getContextOrSparkClassLoader(); + + try { + Class pluginClass = loader.loadClass(pluginClassName); + + if (!CatalogPlugin.class.isAssignableFrom(pluginClass)) { + throw new SparkException(String.format( + "Plugin class for catalog '%s' does not implement CatalogPlugin: %s", + name, pluginClassName)); + } + + CatalogPlugin plugin = CatalogPlugin.class.cast(pluginClass.newInstance()); + + plugin.initialize(name, catalogOptions(name, conf)); + + return plugin; + + } catch (ClassNotFoundException e) { + throw new SparkException(String.format( + "Cannot find catalog plugin class for catalog '%s': %s", name, pluginClassName)); + + } catch (IllegalAccessException e) { + throw new SparkException(String.format( + "Failed to call public no-arg constructor for catalog '%s': %s", name, pluginClassName), + e); + + } catch (InstantiationException e) { + throw new SparkException(String.format( + "Failed while instantiating plugin for catalog '%s': %s", name, pluginClassName), + e.getCause()); + } + } + + /** + * Extracts a named catalog's configuration from a SQLConf. + * + * @param name a catalog name + * @param conf a SQLConf + * @return a case insensitive string map of options starting with spark.sql.catalog.(name). + */ + private static CaseInsensitiveStringMap catalogOptions(String name, SQLConf conf) { + Map allConfs = mapAsJavaMapConverter(conf.getAllConfs()).asJava(); + Pattern prefix = Pattern.compile("^spark\\.sql\\.catalog\\." + name + "\\.(.+)"); + + CaseInsensitiveStringMap options = CaseInsensitiveStringMap.empty(); + for (Map.Entry entry : allConfs.entrySet()) { + Matcher matcher = prefix.matcher(entry.getKey()); + if (matcher.matches() && matcher.groupCount() > 0) { + options.put(matcher.group(1), entry.getValue()); + } + } + + return options; + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java new file mode 100644 index 0000000000000..8c5a6c61d8658 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util; + +import org.apache.spark.annotation.Experimental; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Set; + +/** + * Case-insensitive map of string keys to string values. + *
<p>
+ * This is used to pass options to v2 implementations to ensure consistent case insensitivity. + *
<p>
+ * Methods that return keys in this map, like {@link #entrySet()} and {@link #keySet()}, return + * keys converted to lower case. + */ +@Experimental +public class CaseInsensitiveStringMap implements Map { + + public static CaseInsensitiveStringMap empty() { + return new CaseInsensitiveStringMap(); + } + + private final Map delegate; + + private CaseInsensitiveStringMap() { + this.delegate = new HashMap<>(); + } + + @Override + public int size() { + return delegate.size(); + } + + @Override + public boolean isEmpty() { + return delegate.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + return delegate.containsKey(key.toString().toLowerCase(Locale.ROOT)); + } + + @Override + public boolean containsValue(Object value) { + return delegate.containsValue(value); + } + + @Override + public String get(Object key) { + return delegate.get(key.toString().toLowerCase(Locale.ROOT)); + } + + @Override + public String put(String key, String value) { + return delegate.put(key.toLowerCase(Locale.ROOT), value); + } + + @Override + public String remove(Object key) { + return delegate.remove(key.toString().toLowerCase(Locale.ROOT)); + } + + @Override + public void putAll(Map m) { + for (Map.Entry entry : m.entrySet()) { + put(entry.getKey(), entry.getValue()); + } + } + + @Override + public void clear() { + delegate.clear(); + } + + @Override + public Set keySet() { + return delegate.keySet(); + } + + @Override + public Collection values() { + return delegate.values(); + } + + @Override + public Set> entrySet() { + return delegate.entrySet(); + } +} diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java new file mode 100644 index 0000000000000..2f55da83e2a49 --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java @@ -0,0 +1,208 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.SparkException; +import org.apache.spark.sql.internal.SQLConf; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; +import org.junit.Assert; +import org.junit.Test; + +import java.util.concurrent.Callable; + +public class CatalogLoadingSuite { + @Test + public void testLoad() throws SparkException { + SQLConf conf = new SQLConf(); + conf.setConfString("spark.sql.catalog.test-name", TestCatalogPlugin.class.getCanonicalName()); + + CatalogPlugin plugin = Catalogs.load("test-name", conf); + Assert.assertNotNull("Should instantiate a non-null plugin", plugin); + Assert.assertEquals("Plugin should have correct implementation", + TestCatalogPlugin.class, plugin.getClass()); + + TestCatalogPlugin testPlugin = (TestCatalogPlugin) plugin; + Assert.assertEquals("Options should contain no keys", 0, testPlugin.options.size()); + Assert.assertEquals("Catalog should have correct name", "test-name", testPlugin.name()); + } + + @Test + public void testInitializationOptions() throws SparkException { + SQLConf conf = new SQLConf(); + conf.setConfString("spark.sql.catalog.test-name", TestCatalogPlugin.class.getCanonicalName()); + conf.setConfString("spark.sql.catalog.test-name.name", "not-catalog-name"); + conf.setConfString("spark.sql.catalog.test-name.kEy", "valUE"); + + CatalogPlugin plugin = Catalogs.load("test-name", conf); + Assert.assertNotNull("Should instantiate a non-null plugin", plugin); + Assert.assertEquals("Plugin should have correct implementation", + TestCatalogPlugin.class, plugin.getClass()); + + TestCatalogPlugin testPlugin = (TestCatalogPlugin) plugin; + + Assert.assertEquals("Options should contain only two keys", 2, testPlugin.options.size()); + Assert.assertEquals("Options should contain correct value for name (not overwritten)", + "not-catalog-name", testPlugin.options.get("name")); + Assert.assertEquals("Options should contain correct value for key", + "valUE", testPlugin.options.get("key")); + } + + @Test + public void testLoadWithoutConfig() { + SQLConf conf = new SQLConf(); + + SparkException exc = intercept(SparkException.class, () -> Catalogs.load("missing", conf)); + + Assert.assertTrue("Should complain that implementation is not configured", + exc.getMessage() + .contains("plugin class not found: spark.sql.catalog.missing is not defined")); + Assert.assertTrue("Should identify the catalog by name", + exc.getMessage().contains("missing")); + } + + @Test + public void testLoadMissingClass() { + SQLConf conf = new SQLConf(); + conf.setConfString("spark.sql.catalog.missing", "com.example.NoSuchCatalogPlugin"); + + SparkException exc = intercept(SparkException.class, () -> Catalogs.load("missing", conf)); + + Assert.assertTrue("Should complain that the class is not found", + exc.getMessage().contains("Cannot find catalog plugin class")); + Assert.assertTrue("Should identify the catalog by name", + exc.getMessage().contains("missing")); + Assert.assertTrue("Should identify the missing class", + exc.getMessage().contains("com.example.NoSuchCatalogPlugin")); + } + + @Test + public void testLoadNonCatalogPlugin() { + SQLConf conf = new SQLConf(); + String invalidClassName = InvalidCatalogPlugin.class.getCanonicalName(); + conf.setConfString("spark.sql.catalog.invalid", invalidClassName); + + SparkException exc = intercept(SparkException.class, () -> Catalogs.load("invalid", conf)); + + Assert.assertTrue("Should complain that class does not implement CatalogPlugin", + exc.getMessage().contains("does not 
implement CatalogPlugin")); + Assert.assertTrue("Should identify the catalog by name", + exc.getMessage().contains("invalid")); + Assert.assertTrue("Should identify the class", + exc.getMessage().contains(invalidClassName)); + } + + @Test + public void testLoadConstructorFailureCatalogPlugin() { + SQLConf conf = new SQLConf(); + String invalidClassName = ConstructorFailureCatalogPlugin.class.getCanonicalName(); + conf.setConfString("spark.sql.catalog.invalid", invalidClassName); + + RuntimeException exc = intercept(RuntimeException.class, () -> Catalogs.load("invalid", conf)); + + Assert.assertTrue("Should have expected error message", + exc.getMessage().contains("Expected failure")); + } + + @Test + public void testLoadAccessErrorCatalogPlugin() { + SQLConf conf = new SQLConf(); + String invalidClassName = AccessErrorCatalogPlugin.class.getCanonicalName(); + conf.setConfString("spark.sql.catalog.invalid", invalidClassName); + + SparkException exc = intercept(SparkException.class, () -> Catalogs.load("invalid", conf)); + + Assert.assertTrue("Should complain that no public constructor is provided", + exc.getMessage().contains("Failed to call public no-arg constructor for catalog")); + Assert.assertTrue("Should identify the catalog by name", + exc.getMessage().contains("invalid")); + Assert.assertTrue("Should identify the class", + exc.getMessage().contains(invalidClassName)); + } + + @SuppressWarnings("unchecked") + public static E intercept(Class expected, Callable callable) { + try { + callable.call(); + Assert.fail("No exception was thrown, expected: " + + expected.getName()); + } catch (Exception actual) { + try { + Assert.assertEquals(expected, actual.getClass()); + return (E) actual; + } catch (AssertionError e) { + e.addSuppressed(actual); + throw e; + } + } + // Compiler doesn't catch that Assert.fail will always throw an exception. + throw new UnsupportedOperationException("[BUG] Should not reach this statement"); + } +} + +class TestCatalogPlugin implements CatalogPlugin { + String name = null; + CaseInsensitiveStringMap options = null; + + TestCatalogPlugin() { + } + + @Override + public void initialize(String name, CaseInsensitiveStringMap options) { + this.name = name; + this.options = options; + } + + @Override + public String name() { + return name; + } +} + +class ConstructorFailureCatalogPlugin implements CatalogPlugin { // fails in its constructor + ConstructorFailureCatalogPlugin() { + throw new RuntimeException("Expected failure."); + } + + @Override + public void initialize(String name, CaseInsensitiveStringMap options) { + } + + @Override + public String name() { + return null; + } +} + +class AccessErrorCatalogPlugin implements CatalogPlugin { // no public constructor + private AccessErrorCatalogPlugin() { + } + + @Override + public void initialize(String name, CaseInsensitiveStringMap options) { + } + + @Override + public String name() { + return null; + } +} + +class InvalidCatalogPlugin { // doesn't implement CatalogPlugin + public void initialize(CaseInsensitiveStringMap options) { + } +} diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.java new file mode 100644 index 0000000000000..76392777d42a4 --- /dev/null +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.java @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util; + +import org.junit.Assert; +import org.junit.Test; + +import java.util.HashSet; +import java.util.Set; + +public class CaseInsensitiveStringMapSuite { + @Test + public void testPutAndGet() { + CaseInsensitiveStringMap options = CaseInsensitiveStringMap.empty(); + options.put("kEy", "valUE"); + + Assert.assertEquals("Should return correct value for lower-case key", + "valUE", options.get("key")); + Assert.assertEquals("Should return correct value for upper-case key", + "valUE", options.get("KEY")); + } + + @Test + public void testKeySet() { + CaseInsensitiveStringMap options = CaseInsensitiveStringMap.empty(); + options.put("kEy", "valUE"); + + Set expectedKeySet = new HashSet<>(); + expectedKeySet.add("key"); + + Assert.assertEquals("Should return lower-case key set", expectedKeySet, options.keySet()); + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index c5d14dfffd9b2..ff5ca2ac1111a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -21,6 +21,7 @@ import java.io.Closeable import java.util.concurrent.atomic.AtomicReference import scala.collection.JavaConverters._ +import scala.collection.mutable import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal @@ -31,6 +32,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalog.Catalog +import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Catalogs} import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.encoders._ @@ -619,6 +621,12 @@ class SparkSession private( */ @transient lazy val catalog: Catalog = new CatalogImpl(self) + @transient private lazy val catalogs = new mutable.HashMap[String, CatalogPlugin]() + + private[sql] def catalog(name: String): CatalogPlugin = synchronized { + catalogs.getOrElseUpdate(name, Catalogs.load(name, sessionState.conf)) + } + /** * Returns the specified table/view as a `DataFrame`. * From 3da923b9364f0041f876b18fe5f9a51aca058014 Mon Sep 17 00:00:00 2001 From: John Zhuge Date: Thu, 21 Mar 2019 18:04:50 -0700 Subject: [PATCH 05/70] [SPARK-26946][SQL] Identifiers for multi-catalog ## What changes were proposed in this pull request? - Support N-part identifier in SQL - N-part identifier extractor in Analyzer ## How was this patch tested? - A new unit test suite ResolveMultipartRelationSuite - CatalogLoadingSuite rblue cloud-fan mccheah Closes #23848 from jzhuge/SPARK-26946. 
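As a rough illustration only (not part of this patch; `IdentifierDemo` and `describe` are made-up names), the extractors added in `LookupCatalog` could be exercised like this, with the parts typically produced by `parseMultipartIdentifier`:

```scala
import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog}

object IdentifierDemo extends LookupCatalog {
  // In Spark this function is backed by SparkSession.catalog(name) / Catalogs.load;
  // here every lookup fails, so multi-part names fall back to the session catalog.
  override def lookupCatalog: Option[String => CatalogPlugin] =
    Some(name => throw new CatalogNotFoundException(s"Catalog '$name' is not defined"))

  def describe(parts: Seq[String]): String = parts match {
    case CatalogObjectIdentifier(Some(catalog), ident) =>
      s"catalog=${catalog.name()} namespace=${ident.namespace.mkString(".")} name=${ident.name}"
    case AsTableIdentifier(tableIdent) =>
      s"session catalog table ${tableIdent.unquotedString}"
    case _ =>
      s"unsupported identifier ${parts.mkString(".")}"
  }
}

// With the failing lookup above, a two-part name falls back to a v1 identifier:
//   IdentifierDemo.describe(Seq("db", "tbl")) == "session catalog table db.tbl"
```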
Authored-by: John Zhuge Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/parser/SqlBase.g4 | 8 ++ .../apache/spark/sql/catalog/v2/Catalogs.java | 8 +- .../spark/sql/catalog/v2/Identifier.java | 41 ++++++++ .../spark/sql/catalog/v2/IdentifierImpl.java | 45 +++++++++ .../catalog/v2/CatalogNotFoundException.scala | 28 ++++++ .../spark/sql/catalog/v2/LookupCatalog.scala | 74 ++++++++++++++ .../sql/catalyst/analysis/Analyzer.scala | 11 ++- .../sql/catalyst/parser/AstBuilder.scala | 13 +++ .../sql/catalyst/parser/ParseDriver.scala | 7 ++ .../sql/catalyst/parser/ParserInterface.scala | 6 ++ .../sql/catalog/v2/CatalogLoadingSuite.java | 3 +- .../v2/ResolveMultipartIdentifierSuite.scala | 99 +++++++++++++++++++ .../sql/SparkSessionExtensionSuite.scala | 3 + 13 files changed, 340 insertions(+), 6 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Identifier.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/IdentifierImpl.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogNotFoundException.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/ResolveMultipartIdentifierSuite.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index b39681d886c5c..76a13c5e2478f 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -58,6 +58,10 @@ singleTableIdentifier : tableIdentifier EOF ; +singleMultipartIdentifier + : multipartIdentifier EOF + ; + singleFunctionIdentifier : functionIdentifier EOF ; @@ -539,6 +543,10 @@ rowFormat (NULL DEFINED AS nullDefinedAs=STRING)? #rowFormatDelimited ; +multipartIdentifier + : parts+=identifier ('.' parts+=identifier)* + ; + tableIdentifier : (db=identifier '.')? table=identifier ; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java index efae26636a4bc..bcb1f56789daf 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java @@ -43,12 +43,14 @@ private Catalogs() { * @param name a String catalog name * @param conf a SQLConf * @return an initialized CatalogPlugin - * @throws SparkException If the plugin class cannot be found or instantiated + * @throws CatalogNotFoundException if the plugin class cannot be found + * @throws SparkException if the plugin class cannot be instantiated */ - public static CatalogPlugin load(String name, SQLConf conf) throws SparkException { + public static CatalogPlugin load(String name, SQLConf conf) + throws CatalogNotFoundException, SparkException { String pluginClassName = conf.getConfString("spark.sql.catalog." 
+ name, null); if (pluginClassName == null) { - throw new SparkException(String.format( + throw new CatalogNotFoundException(String.format( "Catalog '%s' plugin class not found: spark.sql.catalog.%s is not defined", name, name)); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Identifier.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Identifier.java new file mode 100644 index 0000000000000..3e697c1945bfc --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Identifier.java @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.annotation.Experimental; + +/** + * Identifies an object in a catalog. + */ +@Experimental +public interface Identifier { + + static Identifier of(String[] namespace, String name) { + return new IdentifierImpl(namespace, name); + } + + /** + * @return the namespace in the catalog + */ + String[] namespace(); + + /** + * @return the object name + */ + String name(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/IdentifierImpl.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/IdentifierImpl.java new file mode 100644 index 0000000000000..8874faa71b5bb --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/IdentifierImpl.java @@ -0,0 +1,45 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.annotation.Experimental; + +/** + * An {@link Identifier} implementation. 
+ */ +@Experimental +class IdentifierImpl implements Identifier { + + private String[] namespace; + private String name; + + IdentifierImpl(String[] namespace, String name) { + this.namespace = namespace; + this.name = name; + } + + @Override + public String[] namespace() { + return namespace; + } + + @Override + public String name() { + return name; + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogNotFoundException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogNotFoundException.scala new file mode 100644 index 0000000000000..86de1c9285b73 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogNotFoundException.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2 + +import org.apache.spark.SparkException +import org.apache.spark.annotation.Experimental + +@Experimental +class CatalogNotFoundException(message: String, cause: Throwable) + extends SparkException(message, cause) { + + def this(message: String) = this(message, null) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala new file mode 100644 index 0000000000000..932d32022702b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala @@ -0,0 +1,74 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2 + +import org.apache.spark.annotation.Experimental +import org.apache.spark.sql.catalyst.TableIdentifier + +/** + * A trait to encapsulate catalog lookup function and helpful extractors. + */ +@Experimental +trait LookupCatalog { + + def lookupCatalog: Option[(String) => CatalogPlugin] = None + + type CatalogObjectIdentifier = (Option[CatalogPlugin], Identifier) + + /** + * Extract catalog plugin and identifier from a multi-part identifier. 
+ */ + object CatalogObjectIdentifier { + def unapply(parts: Seq[String]): Option[CatalogObjectIdentifier] = lookupCatalog.map { lookup => + parts match { + case Seq(name) => + (None, Identifier.of(Array.empty, name)) + case Seq(catalogName, tail @ _*) => + try { + val catalog = lookup(catalogName) + (Some(catalog), Identifier.of(tail.init.toArray, tail.last)) + } catch { + case _: CatalogNotFoundException => + (None, Identifier.of(parts.init.toArray, parts.last)) + } + } + } + } + + /** + * Extract legacy table identifier from a multi-part identifier. + * + * For legacy support only. Please use + * [[org.apache.spark.sql.catalog.v2.LookupCatalog.CatalogObjectIdentifier]] in DSv2 code paths. + */ + object AsTableIdentifier { + def unapply(parts: Seq[String]): Option[TableIdentifier] = parts match { + case CatalogObjectIdentifier(None, ident) => + ident.namespace match { + case Array() => + Some(TableIdentifier(ident.name)) + case Array(database) => + Some(TableIdentifier(ident.name, Some(database))) + case _ => + None + } + case _ => + None + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 0e95c10065676..b60dd272c7a08 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -24,6 +24,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalog.v2.{CatalogPlugin, LookupCatalog} import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes @@ -95,13 +96,19 @@ object AnalysisContext { class Analyzer( catalog: SessionCatalog, conf: SQLConf, - maxIterations: Int) - extends RuleExecutor[LogicalPlan] with CheckAnalysis { + maxIterations: Int, + override val lookupCatalog: Option[(String) => CatalogPlugin] = None) + extends RuleExecutor[LogicalPlan] with CheckAnalysis with LookupCatalog { def this(catalog: SessionCatalog, conf: SQLConf) = { this(catalog, conf, conf.optimizerMaxIterations) } + def this(lookupCatalog: Option[(String) => CatalogPlugin], catalog: SessionCatalog, + conf: SQLConf) = { + this(catalog, conf, conf.optimizerMaxIterations, lookupCatalog) + } + def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = { AnalysisHelper.markInAnalyzer { val analyzed = executeAndTrack(plan, tracker) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index a27c6d3c3671c..aa6d8cf7e5ad0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -86,6 +86,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging visitFunctionIdentifier(ctx.functionIdentifier) } + override def visitSingleMultipartIdentifier( + ctx: SingleMultipartIdentifierContext): Seq[String] = withOrigin(ctx) { + visitMultipartIdentifier(ctx.multipartIdentifier) + } + override def visitSingleDataType(ctx: SingleDataTypeContext): DataType = withOrigin(ctx) { visitSparkDataType(ctx.dataType) } @@ -953,6 +958,14 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with 
Logging FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText)) } + /** + * Create a multi-part identifier. + */ + override def visitMultipartIdentifier( + ctx: MultipartIdentifierContext): Seq[String] = withOrigin(ctx) { + ctx.parts.asScala.map(_.getText) + } + /* ******************************************************************************************** * Expression parsing * ******************************************************************************************** */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index 2128a10d0b1bc..ffc64f78e3003 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -57,6 +57,13 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { } } + /** Creates a multi-part identifier for a given SQL string */ + override def parseMultipartIdentifier(sqlText: String): Seq[String] = { + parse(sqlText) { parser => + astBuilder.visitSingleMultipartIdentifier(parser.singleMultipartIdentifier()) + } + } + /** * Creates StructType for a given SQL string, which is a comma separated list of field * definitions which will preserve the correct Hive metadata. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala index 75240d2196222..77e357ad073da 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserInterface.scala @@ -52,6 +52,12 @@ trait ParserInterface { @throws[ParseException]("Text cannot be parsed to a FunctionIdentifier") def parseFunctionIdentifier(sqlText: String): FunctionIdentifier + /** + * Parse a string to a multi-part identifier. + */ + @throws[ParseException]("Text cannot be parsed to a multi-part identifier") + def parseMultipartIdentifier(sqlText: String): Seq[String] + /** * Parse a string to a [[StructType]]. The passed SQL string should be a comma separated list * of field definitions which will preserve the correct Hive metadata. 
diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java index 2f55da83e2a49..326b12f3618d3 100644 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java +++ b/sql/catalyst/src/test/java/org/apache/spark/sql/catalog/v2/CatalogLoadingSuite.java @@ -66,7 +66,8 @@ public void testInitializationOptions() throws SparkException { public void testLoadWithoutConfig() { SQLConf conf = new SQLConf(); - SparkException exc = intercept(SparkException.class, () -> Catalogs.load("missing", conf)); + SparkException exc = intercept(CatalogNotFoundException.class, + () -> Catalogs.load("missing", conf)); Assert.assertTrue("Should complain that implementation is not configured", exc.getMessage() diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/ResolveMultipartIdentifierSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/ResolveMultipartIdentifierSuite.scala new file mode 100644 index 0000000000000..0f2d67eaa9b20 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/ResolveMultipartIdentifierSuite.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.catalog.v2 + +import org.scalatest.Matchers._ + +import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +private class TestCatalogPlugin(override val name: String) extends CatalogPlugin { + + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = Unit +} + +class ResolveMultipartIdentifierSuite extends AnalysisTest { + import CatalystSqlParser._ + + private val analyzer = makeAnalyzer(caseSensitive = false) + + private val catalogs = Seq("prod", "test").map(name => name -> new TestCatalogPlugin(name)).toMap + + private def lookupCatalog(catalog: String): CatalogPlugin = + catalogs.getOrElse(catalog, throw new CatalogNotFoundException("Not found")) + + private def makeAnalyzer(caseSensitive: Boolean) = { + val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) + new Analyzer(Some(lookupCatalog _), null, conf) + } + + override protected def getAnalyzer(caseSensitive: Boolean) = analyzer + + private def checkResolution(sqlText: String, expectedCatalog: Option[CatalogPlugin], + expectedNamespace: Array[String], expectedName: String): Unit = { + + import analyzer.CatalogObjectIdentifier + val CatalogObjectIdentifier(catalog, ident) = parseMultipartIdentifier(sqlText) + catalog shouldEqual expectedCatalog + ident.namespace shouldEqual expectedNamespace + ident.name shouldEqual expectedName + } + + private def checkTableResolution(sqlText: String, + expectedIdent: Option[TableIdentifier]): Unit = { + + import analyzer.AsTableIdentifier + parseMultipartIdentifier(sqlText) match { + case AsTableIdentifier(ident) => + assert(Some(ident) === expectedIdent) + case _ => + assert(None === expectedIdent) + } + } + + test("resolve multipart identifier") { + checkResolution("tbl", None, Array.empty, "tbl") + checkResolution("db.tbl", None, Array("db"), "tbl") + checkResolution("prod.func", catalogs.get("prod"), Array.empty, "func") + checkResolution("ns1.ns2.tbl", None, Array("ns1", "ns2"), "tbl") + checkResolution("prod.db.tbl", catalogs.get("prod"), Array("db"), "tbl") + checkResolution("test.db.tbl", catalogs.get("test"), Array("db"), "tbl") + checkResolution("test.ns1.ns2.ns3.tbl", + catalogs.get("test"), Array("ns1", "ns2", "ns3"), "tbl") + checkResolution("`db.tbl`", None, Array.empty, "db.tbl") + checkResolution("parquet.`file:/tmp/db.tbl`", None, Array("parquet"), "file:/tmp/db.tbl") + checkResolution("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", None, + Array("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json") + } + + test("resolve table identifier") { + checkTableResolution("tbl", Some(TableIdentifier("tbl"))) + checkTableResolution("db.tbl", Some(TableIdentifier("tbl", Some("db")))) + checkTableResolution("prod.func", None) + checkTableResolution("ns1.ns2.tbl", None) + checkTableResolution("prod.db.tbl", None) + checkTableResolution("`db.tbl`", Some(TableIdentifier("db.tbl"))) + checkTableResolution("parquet.`file:/tmp/db.tbl`", + Some(TableIdentifier("file:/tmp/db.tbl", Some("parquet")))) + checkTableResolution("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", + Some(TableIdentifier("s3://buck/tmp/abc.json", Some("org.apache.spark.sql.json")))) + } +} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala index 9f33feb1950c7..881268440ccd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionExtensionSuite.scala @@ -234,6 +234,9 @@ case class MyParser(spark: SparkSession, delegate: ParserInterface) extends Pars override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier = delegate.parseFunctionIdentifier(sqlText) + override def parseMultipartIdentifier(sqlText: String): Seq[String] = + delegate.parseMultipartIdentifier(sqlText) + override def parseTableSchema(sqlText: String): StructType = delegate.parseTableSchema(sqlText) From 85d0f089e5376fc773561d283950815b8b6d7e84 Mon Sep 17 00:00:00 2001 From: John Zhuge Date: Sun, 24 Mar 2019 09:05:41 -0500 Subject: [PATCH 06/70] [SPARK-27250][TEST-MAVEN][BUILD] Scala 2.11 maven compile should target Java 1.8 ## What changes were proposed in this pull request? Fix Scala 2.11 maven build issue after merging SPARK-26946. ## How was this patch tested? Maven Scala 2.11 and 2.12 builds with `-Phadoop-provided -Phadoop-2.7 -Pyarn -Phive -Phive-thriftserver`. Closes #24184 from jzhuge/SPARK-26946-1. Authored-by: John Zhuge Signed-off-by: Sean Owen --- pom.xml | 1 + 1 file changed, 1 insertion(+) diff --git a/pom.xml b/pom.xml index 5a23ffbc41dd3..9f5699582f5af 100644 --- a/pom.xml +++ b/pom.xml @@ -2381,6 +2381,7 @@ -feature -explaintypes -Yno-adapted-args + -target:jvm-1.8 -Xms1024m From 4db1e195e12860c4cdc4058c8650ced315741e18 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Mon, 25 Feb 2019 16:20:06 -0800 Subject: [PATCH 07/70] [SPARK-26673][FOLLOWUP][SQL] File Source V2: check existence of output path before delete it ## What changes were proposed in this pull request? This is a followup PR to resolve comment: https://github.com/apache/spark/pull/23601#pullrequestreview-207101115 When Spark writes DataFrame with "overwrite" mode, it deletes the output path before actual writes. To safely handle the case that the output path doesn't exist, it is suggested to follow the V1 code by checking the existence. ## How was this patch tested? Apply https://github.com/apache/spark/pull/23836 and run unit tests Closes #23889 from gengliangwang/checkFileBeforeOverwrite. 
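A minimal sketch of the guarded delete described above, assuming only a Hadoop `FileSystem` handle and a `deleteWithJob` callback standing in for the committer call (the actual change to `FileWriteBuilder` is in the diff below):

```scala
import java.io.IOException

import org.apache.hadoop.fs.{FileSystem, Path}

// Only attempt the delete when the output directory actually exists; a missing
// path is fine, but failing to clear an existing one should abort the write.
def clearOutputPath(
    fs: FileSystem,
    path: Path,
    deleteWithJob: (FileSystem, Path) => Boolean): Unit = {
  if (fs.exists(path) && !deleteWithJob(fs, path)) {
    throw new IOException(s"Unable to clear directory $path prior to writing to it")
  }
}
```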
Authored-by: Gengliang Wang Signed-off-by: gatorsmile --- .../sql/execution/datasources/v2/FileWriteBuilder.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala index ce9b52f29d7bd..4da4eb333e617 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala @@ -16,6 +16,7 @@ */ package org.apache.spark.sql.execution.datasources.v2 +import java.io.IOException import java.util.UUID import scala.collection.JavaConverters._ @@ -83,7 +84,9 @@ abstract class FileWriteBuilder(options: DataSourceOptions) null case SaveMode.Overwrite => - committer.deleteWithJob(fs, path, true) + if (fs.exists(path) && !committer.deleteWithJob(fs, path, true)) { + throw new IOException(s"Unable to clear directory $path prior to writing to it") + } committer.setupJob(job) new FileBatchWrite(job, description, committer) From 273330163823d3b0b3e23ade544263179e16c98f Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Tue, 26 Feb 2019 14:10:54 +0800 Subject: [PATCH 08/70] [SPARK-26952][SQL] Row count statics should respect the data reported by data source ## What changes were proposed in this pull request? In data source v2, if the data source scan implemented `SupportsReportStatistics`. `DataSourceV2Relation` should respect the row count reported by the data source. ## How was this patch tested? New UT test. Closes #23853 from ConeyLiu/report-row-count. Authored-by: Xianyang Liu Signed-off-by: Wenchen Fan --- .../datasources/v2/DataSourceV2Relation.scala | 23 ++++++- .../v2/JavaReportStatisticsDataSource.java | 65 +++++++++++++++++++ .../sql/sources/v2/DataSourceV2Suite.scala | 46 ++++++++++++- 3 files changed, 130 insertions(+), 4 deletions(-) create mode 100644 sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaReportStatisticsDataSource.java diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 53677782c95f4..891694be46291 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader._ +import org.apache.spark.sql.sources.v2.reader.{Statistics => V2Statistics, _} import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream} import org.apache.spark.sql.sources.v2.writer._ @@ -56,7 +56,7 @@ case class DataSourceV2Relation( scan match { case r: SupportsReportStatistics => val statistics = r.estimateStatistics() - Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + DataSourceV2Relation.transformV2Stats(statistics, None, conf.defaultSizeInBytes) case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } @@ -89,7 +89,7 @@ case class StreamingDataSourceV2Relation( override def 
computeStats(): Statistics = scan match { case r: SupportsReportStatistics => val statistics = r.estimateStatistics() - Statistics(sizeInBytes = statistics.sizeInBytes().orElse(conf.defaultSizeInBytes)) + DataSourceV2Relation.transformV2Stats(statistics, None, conf.defaultSizeInBytes) case _ => Statistics(sizeInBytes = conf.defaultSizeInBytes) } @@ -100,4 +100,21 @@ object DataSourceV2Relation { val output = table.schema().toAttributes DataSourceV2Relation(table, output, options) } + + /** + * This is used to transform data source v2 statistics to logical.Statistics. + */ + def transformV2Stats( + v2Statistics: V2Statistics, + defaultRowCount: Option[BigInt], + defaultSizeInBytes: Long): Statistics = { + val numRows: Option[BigInt] = if (v2Statistics.numRows().isPresent) { + Some(v2Statistics.numRows().getAsLong) + } else { + defaultRowCount + } + Statistics( + sizeInBytes = v2Statistics.sizeInBytes().orElse(defaultSizeInBytes), + rowCount = numRows) + } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaReportStatisticsDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaReportStatisticsDataSource.java new file mode 100644 index 0000000000000..bbc8492ec4e16 --- /dev/null +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaReportStatisticsDataSource.java @@ -0,0 +1,65 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package test.org.apache.spark.sql.sources.v2; + +import java.util.OptionalLong; + +import org.apache.spark.sql.sources.v2.DataSourceOptions; +import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.TableProvider; +import org.apache.spark.sql.sources.v2.reader.InputPartition; +import org.apache.spark.sql.sources.v2.reader.ScanBuilder; +import org.apache.spark.sql.sources.v2.reader.Statistics; +import org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics; + +public class JavaReportStatisticsDataSource implements TableProvider { + class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportStatistics { + @Override + public Statistics estimateStatistics() { + return new Statistics() { + @Override + public OptionalLong sizeInBytes() { + return OptionalLong.of(80); + } + + @Override + public OptionalLong numRows() { + return OptionalLong.of(10); + } + }; + } + + @Override + public InputPartition[] planInputPartitions() { + InputPartition[] partitions = new InputPartition[2]; + partitions[0] = new JavaRangeInputPartition(0, 5); + partitions[1] = new JavaRangeInputPartition(5, 10); + return partitions; + } + } + + @Override + public Table getTable(DataSourceOptions options) { + return new JavaSimpleBatchTable() { + @Override + public ScanBuilder newScanBuilder(DataSourceOptions options) { + return new MyScanBuilder(); + } + }; + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 6b5c45e40ab0c..b8572448f736e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources.v2 import java.io.File +import java.util.OptionalLong import test.org.apache.spark.sql.sources.v2._ @@ -182,6 +183,24 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } + test ("statistics report data source") { + Seq(classOf[ReportStatisticsDataSource], classOf[JavaReportStatisticsDataSource]).foreach { + cls => + withClue(cls.getName) { + val df = spark.read.format(cls.getName).load() + val logical = df.queryExecution.optimizedPlan.collect { + case d: DataSourceV2Relation => d + }.head + + val statics = logical.computeStats() + assert(statics.rowCount.isDefined && statics.rowCount.get === 10, + "Row count statics should be reported by data source") + assert(statics.sizeInBytes === 80, + "Size in bytes statics should be reported by data source") + } + } + } + test("SPARK-23574: no shuffle exchange with single partition") { val df = spark.read.format(classOf[SimpleSinglePartitionSource].getName).load().agg(count("*")) assert(df.queryExecution.executedPlan.collect { case e: Exchange => e }.isEmpty) @@ -621,7 +640,6 @@ object ColumnarReaderFactory extends PartitionReaderFactory { } } - class PartitionAwareDataSource extends TableProvider { class MyScanBuilder extends SimpleScanBuilder @@ -689,3 +707,29 @@ class SimpleWriteOnlyDataSource extends SimpleWritableDataSource { } } } + +class ReportStatisticsDataSource extends TableProvider { + + class MyScanBuilder extends SimpleScanBuilder + with SupportsReportStatistics { + override def estimateStatistics(): Statistics = { + new Statistics { + override def sizeInBytes(): OptionalLong = OptionalLong.of(80) + + override def numRows(): OptionalLong = OptionalLong.of(10) + } + } + + override def 
planInputPartitions(): Array[InputPartition] = { + Array(RangeInputPartition(0, 5), RangeInputPartition(5, 10)) + } + } + + override def getTable(options: DataSourceOptions): Table = { + new SimpleBatchTable { + override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + new MyScanBuilder + } + } + } +} From 4094211060d0bd52bb8b660bb15471d4bfefc108 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Fri, 15 Feb 2019 14:57:23 +0800 Subject: [PATCH 09/70] [SPARK-26871][SQL] File Source V2: avoid creating unnecessary FileIndex in the write path ## What changes were proposed in this pull request? In https://github.com/apache/spark/pull/23383, the file source V2 framework is implemented. In the PR, `FileIndex` is created as a member of `FileTable`, so that we can implement partition pruning like https://github.com/apache/spark/commit/0f9fcabb4ac2e8afec14d010e86467372a85d334 in the future(As data source V2 catalog is under development, partition pruning is removed from the PR) However, after write path of file source V2 is implemented, I find that a simple write will create an unnecessary `FileIndex`, which is required by `FileTable`. This is a sort of regression. And we can see there is a warning message when writing to ORC files ``` WARN InMemoryFileIndex: The directory file:/tmp/foo was not found. Was it deleted very recently? ``` This PR is to make `FileIndex` as a lazy value in `FileTable`, so that we can avoid creating unnecessary `FileIndex` in the write path. ## How was this patch tested? Existing unit test Closes #23774 from gengliangwang/moveFileIndexInV2. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- .../datasources/FallbackOrcDataSourceV2.scala | 2 +- .../datasources/v2/FileDataSourceV2.scala | 20 ++----------------- .../execution/datasources/v2/FileTable.scala | 17 +++++++++++++--- .../datasources/v2/orc/OrcDataSourceV2.scala | 6 ++---- .../datasources/v2/orc/OrcTable.scala | 5 ++--- 5 files changed, 21 insertions(+), 29 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala index 254c09001f7ec..e22d6a6d399a5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala @@ -35,7 +35,7 @@ class FallbackOrcDataSourceV2(sparkSession: SparkSession) extends Rule[LogicalPl override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(d @DataSourceV2Relation(table: OrcTable, _, _), _, _, _, _) => val v1FileFormat = new OrcFileFormat - val relation = HadoopFsRelation(table.getFileIndex, table.getFileIndex.partitionSchema, + val relation = HadoopFsRelation(table.fileIndex, table.fileIndex.partitionSchema, table.schema(), None, v1FileFormat, d.options)(sparkSession) i.copy(table = LogicalRelation(relation)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala index a0c932cbb0e09..06c57066aa240 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala @@ -16,13 +16,10 @@ */ package 
org.apache.spark.sql.execution.datasources.v2 -import scala.collection.JavaConverters._ - -import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.DataSourceRegister -import org.apache.spark.sql.sources.v2.{DataSourceOptions, DataSourceV2, SupportsBatchRead, TableProvider} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.sources.v2.TableProvider /** * A base interface for data source v2 implementations of the built-in file-based data sources. @@ -38,17 +35,4 @@ trait FileDataSourceV2 extends TableProvider with DataSourceRegister { def fallBackFileFormat: Class[_ <: FileFormat] lazy val sparkSession = SparkSession.active - - def getFileIndex( - options: DataSourceOptions, - userSpecifiedSchema: Option[StructType]): PartitioningAwareFileIndex = { - val filePaths = options.paths() - val hadoopConf = - sparkSession.sessionState.newHadoopConfWithOptions(options.asMap().asScala.toMap) - val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(filePaths, hadoopConf, - checkEmptyGlobPath = true, checkFilesExist = options.checkFilesExist()) - val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) - new InMemoryFileIndex(sparkSession, rootPathsSpecified, - options.asMap().asScala.toMap, userSpecifiedSchema, fileStatusCache) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index 0dbef145f7326..21d3e5e29cfb5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -16,19 +16,30 @@ */ package org.apache.spark.sql.execution.datasources.v2 +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table} +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsBatchRead, SupportsBatchWrite, Table} import org.apache.spark.sql.types.StructType abstract class FileTable( sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, + options: DataSourceOptions, userSpecifiedSchema: Option[StructType]) extends Table with SupportsBatchRead with SupportsBatchWrite { - def getFileIndex: PartitioningAwareFileIndex = this.fileIndex + lazy val fileIndex: PartitioningAwareFileIndex = { + val filePaths = options.paths() + val hadoopConf = + sparkSession.sessionState.newHadoopConfWithOptions(options.asMap().asScala.toMap) + val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(filePaths, hadoopConf, + checkEmptyGlobPath = true, checkFilesExist = options.checkFilesExist()) + val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) + new InMemoryFileIndex(sparkSession, rootPathsSpecified, + options.asMap().asScala.toMap, userSpecifiedSchema, fileStatusCache) + } lazy val dataSchema: StructType = userSpecifiedSchema.orElse { inferSchema(fileIndex.allFiles()) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala index db1f2f7934221..74739b4fe2d48 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala @@ -34,13 +34,11 @@ class OrcDataSourceV2 extends FileDataSourceV2 { override def getTable(options: DataSourceOptions): Table = { val tableName = getTableName(options) - val fileIndex = getFileIndex(options, None) - OrcTable(tableName, sparkSession, fileIndex, None) + OrcTable(tableName, sparkSession, options, None) } override def getTable(options: DataSourceOptions, schema: StructType): Table = { val tableName = getTableName(options) - val fileIndex = getFileIndex(options, Some(schema)) - OrcTable(tableName, sparkSession, fileIndex, Some(schema)) + OrcTable(tableName, sparkSession, options, Some(schema)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala index b467e505f1bac..249df8b8622fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.v2.orc import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.orc.OrcUtils import org.apache.spark.sql.execution.datasources.v2.FileTable import org.apache.spark.sql.sources.v2.DataSourceOptions @@ -29,9 +28,9 @@ import org.apache.spark.sql.types.StructType case class OrcTable( name: String, sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex, + options: DataSourceOptions, userSpecifiedSchema: Option[StructType]) - extends FileTable(sparkSession, fileIndex, userSpecifiedSchema) { + extends FileTable(sparkSession, options, userSpecifiedSchema) { override def newScanBuilder(options: DataSourceOptions): OrcScanBuilder = new OrcScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) From caa5fab94afbc30c34345e24f9562d87d419ad5a Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Sat, 16 Feb 2019 17:11:36 +0800 Subject: [PATCH 10/70] [SPARK-26744][SQL] Support schema validation in FileDataSourceV2 framework ## What changes were proposed in this pull request? The file source has a schema validation feature, which validates 2 schemas: 1. the user-specified schema when reading. 2. the schema of input data when writing. If a file source doesn't support the schema, we can fail the query earlier. This PR is to implement the same feature in the `FileDataSourceV2` framework. Comparing to `FileFormat`, `FileDataSourceV2` has multiple layers. The API is added in two places: 1. Read path: the table schema is determined in `TableProvider.getTable`. The actual read schema can be a subset of the table schema. This PR proposes to validate the actual read schema in `FileScan`. 2. Write path: validate the actual output schema in `FileWriteBuilder`. ## How was this patch tested? Unit test Closes #23714 from gengliangwang/schemaValidationV2. 
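A simplified sketch of the per-field validation this adds, assuming a recursive type-support predicate similar to the ORC one in the diff below; the real hooks live in `FileScan` and `FileWriteBuilder` and raise `AnalysisException`, while this standalone version throws a generic exception to stay self-contained:

```scala
import org.apache.spark.sql.types._

// Reject types that file formats cannot store, recursing into nested types.
// Each format supplies its own predicate via supportsDataType.
def supportsDataType(dt: DataType): Boolean = dt match {
  case CalendarIntervalType | NullType => false
  case st: StructType => st.forall(f => supportsDataType(f.dataType))
  case ArrayType(elementType, _) => supportsDataType(elementType)
  case MapType(keyType, valueType, _) =>
    supportsDataType(keyType) && supportsDataType(valueType)
  case _ => true
}

// Fail fast on the first unsupported field in the read or write schema.
def validateSchema(formatName: String, schema: StructType): Unit = {
  schema.foreach { field =>
    if (!supportsDataType(field.dataType)) {
      throw new IllegalArgumentException(
        s"$formatName data source does not support ${field.dataType.catalogString} data type.")
    }
  }
}
```

With a check like this in place, both the read path (`FileScan.toBatch`) and the write path (`FileWriteBuilder`) can reject an unsupported schema before any job is launched.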
Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- .../execution/datasources/v2/FileScan.scala | 33 +++- .../datasources/v2/FileWriteBuilder.scala | 24 ++- .../datasources/v2/orc/OrcDataSourceV2.scala | 19 ++- .../datasources/v2/orc/OrcScan.scala | 10 +- .../datasources/v2/orc/OrcWriteBuilder.scala | 6 + .../spark/sql/FileBasedDataSourceSuite.scala | 152 ++++++++++-------- 6 files changed, 167 insertions(+), 77 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala index 3615b15be6fd5..bdd6a48df20ce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileScan.scala @@ -18,15 +18,16 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.hadoop.fs.Path -import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.execution.PartitionedFileUtil import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, Scan} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} abstract class FileScan( sparkSession: SparkSession, - fileIndex: PartitioningAwareFileIndex) extends Scan with Batch { + fileIndex: PartitioningAwareFileIndex, + readSchema: StructType) extends Scan with Batch { /** * Returns whether a file with `path` could be split or not. */ @@ -34,6 +35,22 @@ abstract class FileScan( false } + /** + * Returns whether this format supports the given [[DataType]] in write path. + * By default all data types are supported. + */ + def supportsDataType(dataType: DataType): Boolean = true + + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source. 
For example: + * + * {{{ + * override def formatName(): String = "ORC" + * }}} + */ + def formatName: String + protected def partitions: Seq[FilePartition] = { val selectedPartitions = fileIndex.listFiles(Seq.empty, Seq.empty) val maxSplitBytes = FilePartition.maxSplitBytes(sparkSession, selectedPartitions) @@ -57,5 +74,13 @@ abstract class FileScan( partitions.toArray } - override def toBatch: Batch = this + override def toBatch: Batch = { + readSchema.foreach { field => + if (!supportsDataType(field.dataType)) { + throw new AnalysisException( + s"$formatName data source does not support ${field.dataType.catalogString} data type.") + } + } + this + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala index 4da4eb333e617..75c922424e8ef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration abstract class FileWriteBuilder(options: DataSourceOptions) @@ -107,12 +107,34 @@ abstract class FileWriteBuilder(options: DataSourceOptions) options: Map[String, String], dataSchema: StructType): OutputWriterFactory + /** + * Returns whether this format supports the given [[DataType]] in write path. + * By default all data types are supported. + */ + def supportsDataType(dataType: DataType): Boolean = true + + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source. 
For example: + * + * {{{ + * override def formatName(): String = "ORC" + * }}} + */ + def formatName: String + private def validateInputs(): Unit = { assert(schema != null, "Missing input data schema") assert(queryId != null, "Missing query ID") assert(mode != null, "Missing save mode") assert(options.paths().length == 1) DataSource.validateSchema(schema) + schema.foreach { field => + if (!supportsDataType(field.dataType)) { + throw new AnalysisException( + s"$formatName data source does not support ${field.dataType.catalogString} data type.") + } + } } private def getJobInstance(hadoopConf: Configuration, path: Path): Job = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala index 74739b4fe2d48..f279af49ba9cf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala @@ -20,7 +20,7 @@ import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.v2._ import org.apache.spark.sql.sources.v2.{DataSourceOptions, Table} -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types._ class OrcDataSourceV2 extends FileDataSourceV2 { @@ -42,3 +42,20 @@ class OrcDataSourceV2 extends FileDataSourceV2 { OrcTable(tableName, sparkSession, options, Some(schema)) } } + +object OrcDataSourceV2 { + def supportsDataType(dataType: DataType): Boolean = dataType match { + case _: AtomicType => true + + case st: StructType => st.forall { f => supportsDataType(f.dataType) } + + case ArrayType(elementType, _) => supportsDataType(elementType) + + case MapType(keyType, valueType, _) => + supportsDataType(keyType) && supportsDataType(valueType) + + case udt: UserDefinedType[_] => supportsDataType(udt.sqlType) + + case _ => false + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala index a792ad318b398..3c5dc1f50d7e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScan.scala @@ -23,7 +23,7 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.v2.FileScan import org.apache.spark.sql.sources.v2.reader.PartitionReaderFactory -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.util.SerializableConfiguration case class OrcScan( @@ -31,7 +31,7 @@ case class OrcScan( hadoopConf: Configuration, fileIndex: PartitioningAwareFileIndex, dataSchema: StructType, - readSchema: StructType) extends FileScan(sparkSession, fileIndex) { + readSchema: StructType) extends FileScan(sparkSession, fileIndex, readSchema) { override def isSplitable(path: Path): Boolean = true override def createReaderFactory(): PartitionReaderFactory = { @@ -40,4 +40,10 @@ case class OrcScan( OrcPartitionReaderFactory(sparkSession.sessionState.conf, broadcastedConf, dataSchema, fileIndex.partitionSchema, readSchema) } + + override def supportsDataType(dataType: 
DataType): Boolean = { + OrcDataSourceV2.supportsDataType(dataType) + } + + override def formatName: String = "ORC" } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala index 80429d91d5e4d..1aec4d872a64d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala @@ -63,4 +63,10 @@ class OrcWriteBuilder(options: DataSourceOptions) extends FileWriteBuilder(optio } } } + + override def supportsDataType(dataType: DataType): Boolean = { + OrcDataSourceV2.supportsDataType(dataType) + } + + override def formatName: String = "ORC" } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 54342b691109d..591884095ec38 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -334,83 +334,97 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") { withTempDir { dir => val tempDir = new File(dir, "files").getCanonicalPath - // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well. - withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") { - // write path - Seq("csv", "json", "parquet", "orc").foreach { format => - var msg = intercept[AnalysisException] { - sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.contains("Cannot save interval data type into external storage.")) - - msg = intercept[AnalysisException] { - spark.udf.register("testType", () => new IntervalData()) - sql("select testType()").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) + Seq(true, false).foreach { useV1 => + val useV1List = if (useV1) { + "orc" + } else { + "" } + def errorMessage(format: String, isWrite: Boolean): String = { + if (isWrite && (useV1 || format != "orc")) { + "cannot save interval data type into external storage." + } else { + s"$format data source does not support calendarinterval data type." 
+ } + } + + withSQLConf(SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) { + // write path + Seq("csv", "json", "parquet", "orc").foreach { format => + var msg = intercept[AnalysisException] { + sql("select interval 1 days").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, true))) + } - // read path - Seq("parquet", "csv").foreach { format => - var msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) - - msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support calendarinterval data type.")) + // read path + Seq("parquet", "csv").foreach { format => + var msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", CalendarIntervalType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false))) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new IntervalUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT).contains(errorMessage(format, false))) + } } } } } test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") { - // TODO(SPARK-26744): support data type validating in V2 data source, and test V2 as well. 
- withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> "orc", - SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> "orc") { - withTempDir { dir => - val tempDir = new File(dir, "files").getCanonicalPath - - Seq("parquet", "csv", "orc").foreach { format => - // write path - var msg = intercept[AnalysisException] { - sql("select null").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support null data type.")) - - msg = intercept[AnalysisException] { - spark.udf.register("testType", () => new NullData()) - sql("select testType()").write.format(format).mode("overwrite").save(tempDir) - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support null data type.")) - - // read path - msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", NullType, true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support null data type.")) - - msg = intercept[AnalysisException] { - val schema = StructType(StructField("a", new NullUDT(), true) :: Nil) - spark.range(1).write.format(format).mode("overwrite").save(tempDir) - spark.read.schema(schema).format(format).load(tempDir).collect() - }.getMessage - assert(msg.toLowerCase(Locale.ROOT) - .contains(s"$format data source does not support null data type.")) + Seq(true, false).foreach { useV1 => + val useV1List = if (useV1) { + "orc" + } else { + "" + } + def errorMessage(format: String): String = { + s"$format data source does not support null data type." + } + withSQLConf(SQLConf.USE_V1_SOURCE_READER_LIST.key -> useV1List, + SQLConf.USE_V1_SOURCE_WRITER_LIST.key -> useV1List) { + withTempDir { dir => + val tempDir = new File(dir, "files").getCanonicalPath + + Seq("parquet", "csv", "orc").foreach { format => + // write path + var msg = intercept[AnalysisException] { + sql("select null").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(errorMessage(format))) + + msg = intercept[AnalysisException] { + spark.udf.register("testType", () => new NullData()) + sql("select testType()").write.format(format).mode("overwrite").save(tempDir) + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(errorMessage(format))) + + // read path + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", NullType, true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(errorMessage(format))) + + msg = intercept[AnalysisException] { + val schema = StructType(StructField("a", new NullUDT(), true) :: Nil) + spark.range(1).write.format(format).mode("overwrite").save(tempDir) + spark.read.schema(schema).format(format).load(tempDir).collect() + }.getMessage + assert(msg.toLowerCase(Locale.ROOT) + .contains(errorMessage(format))) + } } } } From d2f0dd55687329f2f9c348462956b628b7eed49f Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 3 Mar 2019 22:20:31 -0800 Subject: [PATCH 11/70] [SPARK-26956][SS] remove streaming output mode from data source v2 APIs ## What changes were proposed in this pull request? 
Similar to `SaveMode`, we should remove streaming `OutputMode` from data source v2 API, and use operations that has clear semantic. The changes are: 1. append mode: create `StreamingWrite` directly. By default, the `WriteBuilder` will create `Write` to append data. 2. complete mode: call `SupportsTruncate#truncate`. Complete mode means truncating all the old data and appending new data of the current epoch. `SupportsTruncate` has exactly the same semantic. 3. update mode: fail. The current streaming framework can't propagate the update keys, so v2 sinks are not able to implement update mode. In the future we can introduce a `SupportsUpdate` trait. The behavior changes: 1. all the v2 sinks(foreach, console, memory, kafka, noop) don't support update mode. The fact is, previously all the v2 sinks implement the update mode wrong. None of them can really support it. 2. kafka sink doesn't support complete mode. The fact is, the kafka sink can only append data. ## How was this patch tested? existing tests Closes #23859 from cloud-fan/update. Authored-by: Wenchen Fan Signed-off-by: gatorsmile --- .../sql/kafka010/KafkaSourceProvider.scala | 6 +-- .../writer/streaming/SupportsOutputMode.java | 29 ----------- .../datasources/noop/NoopDataSource.scala | 7 ++- .../streaming/MicroBatchExecution.scala | 10 +--- .../execution/streaming/StreamExecution.scala | 34 +++++++++++++ .../sql/execution/streaming/console.scala | 10 ++-- .../continuous/ContinuousExecution.scala | 10 +--- .../sources/ForeachWriterTable.scala | 11 ++-- .../streaming/sources/memoryV2.scala | 51 ++++++++----------- .../streaming/MemorySinkV2Suite.scala | 4 +- 10 files changed, 75 insertions(+), 97 deletions(-) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsOutputMode.java diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 6994517b27d6a..01bb1536aa6c5 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.sources.v2.writer.WriteBuilder -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingWrite, SupportsOutputMode} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType @@ -362,7 +362,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { - new WriteBuilder with SupportsOutputMode { + new WriteBuilder { private var inputSchema: StructType = _ override def withInputDataSchema(schema: StructType): WriteBuilder = { @@ -370,8 +370,6 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister this } - override def outputMode(mode: OutputMode): WriteBuilder = this - override def buildForStreaming(): StreamingWrite = { import scala.collection.JavaConverters._ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsOutputMode.java 
b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsOutputMode.java deleted file mode 100644 index 832dcfa145d1b..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/SupportsOutputMode.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.writer.streaming; - -import org.apache.spark.annotation.Unstable; -import org.apache.spark.sql.sources.v2.writer.WriteBuilder; -import org.apache.spark.sql.streaming.OutputMode; - -// TODO: remove it when we have `SupportsTruncate` -@Unstable -public interface SupportsOutputMode extends WriteBuilder { - - WriteBuilder outputMode(OutputMode mode); -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 8f2072c586a94..22a74e3ccaeee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -22,8 +22,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite, SupportsOutputMode} -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType /** @@ -42,9 +41,9 @@ private[noop] object NoopTable extends Table with SupportsBatchWrite with Suppor } private[noop] object NoopWriteBuilder extends WriteBuilder - with SupportsSaveMode with SupportsOutputMode { + with SupportsSaveMode with SupportsTruncate { override def mode(mode: SaveMode): WriteBuilder = this - override def outputMode(mode: OutputMode): WriteBuilder = this + override def truncate(): WriteBuilder = this override def buildForBatch(): BatchWrite = NoopBatchWrite override def buildForStreaming(): StreamingWrite = NoopStreamingWrite } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index cca279030dfa7..de7cbe25ceb3b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -31,7 +31,6 @@ import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWrite, 
RateCo import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset => OffsetV2} -import org.apache.spark.sql.sources.v2.writer.streaming.SupportsOutputMode import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.Clock @@ -515,14 +514,7 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan case s: SupportsStreamingWrite => - // TODO: we should translate OutputMode to concrete write actions like truncate, but - // the truncate action is being developed in SPARK-26666. - val writeBuilder = s.newWriteBuilder(new DataSourceOptions(extraOptions.asJava)) - .withQueryId(runId.toString) - .withInputDataSchema(newAttributePlan.schema) - val streamingWrite = writeBuilder.asInstanceOf[SupportsOutputMode] - .outputMode(outputMode) - .buildForStreaming() + val streamingWrite = createStreamingWrite(s, extraOptions, newAttributePlan) WriteToDataSourceV2(new MicroBatchWrite(currentBatchId, streamingWrite), newAttributePlan) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 90f7b477103ae..ea522ec90b0a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -24,6 +24,7 @@ import java.util.concurrent.{CountDownLatch, ExecutionException, TimeUnit} import java.util.concurrent.atomic.AtomicReference import java.util.concurrent.locks.{Condition, ReentrantLock} +import scala.collection.JavaConverters._ import scala.collection.mutable.{Map => MutableMap} import scala.util.control.NonFatal @@ -34,10 +35,14 @@ import org.apache.spark.{SparkContext, SparkException} import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.execution.datasources.v2.StreamWriterCommitProgress import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite} +import org.apache.spark.sql.sources.v2.writer.SupportsTruncate +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.streaming._ import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} @@ -532,6 +537,35 @@ abstract class StreamExecution( Option(name).map(_ + "
").getOrElse("") + s"id = $id
runId = $runId
batch = $batchDescription" } + + protected def createStreamingWrite( + table: SupportsStreamingWrite, + options: Map[String, String], + inputPlan: LogicalPlan): StreamingWrite = { + val writeBuilder = table.newWriteBuilder(new DataSourceOptions(options.asJava)) + .withQueryId(runId.toString) + .withInputDataSchema(inputPlan.schema) + outputMode match { + case Append => + writeBuilder.buildForStreaming() + + case Complete => + // TODO: we should do this check earlier when we have capability API. + require(writeBuilder.isInstanceOf[SupportsTruncate], + table.name + " does not support Complete mode.") + writeBuilder.asInstanceOf[SupportsTruncate].truncate().buildForStreaming() + + case Update => + // Although no v2 sinks really support Update mode now, but during tests we do want them + // to pretend to support Update mode, and treat Update mode same as Append mode. + if (Utils.isTesting) { + writeBuilder.buildForStreaming() + } else { + throw new IllegalArgumentException( + "Data source v2 streaming sinks does not support Update mode.") + } + } + } } object StreamExecution { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 348bc767b2c46..923bd749b29b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -21,9 +21,8 @@ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.sources.ConsoleWrite import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.writer.WriteBuilder -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingWrite, SupportsOutputMode} -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.sources.v2.writer.{SupportsTruncate, WriteBuilder} +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.types.StructType case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) @@ -64,7 +63,7 @@ object ConsoleTable extends Table with SupportsStreamingWrite { override def schema(): StructType = StructType(Nil) override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { - new WriteBuilder with SupportsOutputMode { + new WriteBuilder with SupportsTruncate { private var inputSchema: StructType = _ override def withInputDataSchema(schema: StructType): WriteBuilder = { @@ -72,7 +71,8 @@ object ConsoleTable extends Table with SupportsStreamingWrite { this } - override def outputMode(mode: OutputMode): WriteBuilder = this + // Do nothing for truncate. Console sink is special that it just prints all the records. 
+ override def truncate(): WriteBuilder = this override def buildForStreaming(): StreamingWrite = { assert(inputSchema != null) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index 20101c7fda320..a1ac55ca4ce25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -34,7 +34,6 @@ import org.apache.spark.sql.execution.streaming.{StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsContinuousRead, SupportsStreamingWrite} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} -import org.apache.spark.sql.sources.v2.writer.streaming.SupportsOutputMode import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.Clock @@ -175,14 +174,7 @@ class ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - // TODO: we should translate OutputMode to concrete write actions like truncate, but - // the truncate action is being developed in SPARK-26666. - val writeBuilder = sink.newWriteBuilder(new DataSourceOptions(extraOptions.asJava)) - .withQueryId(runId.toString) - .withInputDataSchema(withNewSources.schema) - val streamingWrite = writeBuilder.asInstanceOf[SupportsOutputMode] - .outputMode(outputMode) - .buildForStreaming() + val streamingWrite = createStreamingWrite(sink, extraOptions, withNewSources) val planWithSink = WriteToContinuousDataSource(streamingWrite, withNewSources) reportTimeTaken("queryPlanning") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala index 6fbb59c43625a..c0ae44a128ca1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala @@ -23,9 +23,8 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.python.PythonForeachWriter import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite, Table} -import org.apache.spark.sql.sources.v2.writer.{DataWriter, WriteBuilder, WriterCommitMessage} -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite, SupportsOutputMode} -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.sources.v2.writer.{DataWriter, SupportsTruncate, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType /** @@ -46,7 +45,7 @@ case class ForeachWriterTable[T]( override def schema(): StructType = StructType(Nil) override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { - new WriteBuilder with SupportsOutputMode { + new WriteBuilder with SupportsTruncate { private var inputSchema: StructType = _ override def withInputDataSchema(schema: StructType): WriteBuilder = { @@ -54,7 +53,9 
@@ case class ForeachWriterTable[T]( this } - override def outputMode(mode: OutputMode): WriteBuilder = this + // Do nothing for truncate. Foreach sink is special that it just forwards all the records to + // ForeachWriter. + override def truncate(): WriteBuilder = this override def buildForStreaming(): StreamingWrite = { new StreamingWrite { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 3fc2cbe0fde57..397c5ff0dcb6a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -30,12 +30,10 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes.{Append, Complete, Update} import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite} import org.apache.spark.sql.sources.v2.writer._ -import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite, SupportsOutputMode} -import org.apache.spark.sql.streaming.OutputMode +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType /** @@ -49,12 +47,12 @@ class MemorySinkV2 extends SupportsStreamingWrite with MemorySinkBase with Loggi override def schema(): StructType = StructType(Nil) override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { - new WriteBuilder with SupportsOutputMode { - private var mode: OutputMode = _ + new WriteBuilder with SupportsTruncate { + private var needTruncate: Boolean = false private var inputSchema: StructType = _ - override def outputMode(mode: OutputMode): WriteBuilder = { - this.mode = mode + override def truncate(): WriteBuilder = { + this.needTruncate = true this } @@ -64,7 +62,7 @@ class MemorySinkV2 extends SupportsStreamingWrite with MemorySinkBase with Loggi } override def buildForStreaming(): StreamingWrite = { - new MemoryStreamingWrite(MemorySinkV2.this, mode, inputSchema) + new MemoryStreamingWrite(MemorySinkV2.this, inputSchema, needTruncate) } } } @@ -101,27 +99,20 @@ class MemorySinkV2 extends SupportsStreamingWrite with MemorySinkBase with Loggi }.mkString("\n") } - def write(batchId: Long, outputMode: OutputMode, newRows: Array[Row]): Unit = { + def write(batchId: Long, needTruncate: Boolean, newRows: Array[Row]): Unit = { val notCommitted = synchronized { latestBatchId.isEmpty || batchId > latestBatchId.get } if (notCommitted) { logDebug(s"Committing batch $batchId to $this") - outputMode match { - case Append | Update => - val rows = AddedData(batchId, newRows) - synchronized { batches += rows } - - case Complete => - val rows = AddedData(batchId, newRows) - synchronized { - batches.clear() - batches += rows - } - - case _ => - throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySinkV2") + val rows = AddedData(batchId, newRows) + if (needTruncate) { + synchronized { + batches.clear() + batches += rows + } + } else { + synchronized { batches += rows } } } else { 
logDebug(s"Skipping already committed batch: $batchId") @@ -139,18 +130,18 @@ case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} class MemoryStreamingWrite( - val sink: MemorySinkV2, outputMode: OutputMode, schema: StructType) + val sink: MemorySinkV2, schema: StructType, needTruncate: Boolean) extends StreamingWrite { override def createStreamingWriterFactory: MemoryWriterFactory = { - MemoryWriterFactory(outputMode, schema) + MemoryWriterFactory(schema) } override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { val newRows = messages.flatMap { case message: MemoryWriterCommitMessage => message.data } - sink.write(epochId, outputMode, newRows) + sink.write(epochId, needTruncate, newRows) } override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { @@ -158,13 +149,13 @@ class MemoryStreamingWrite( } } -case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType) +case class MemoryWriterFactory(schema: StructType) extends DataWriterFactory with StreamingDataWriterFactory { override def createWriter( partitionId: Int, taskId: Long): DataWriter[InternalRow] = { - new MemoryDataWriter(partitionId, outputMode, schema) + new MemoryDataWriter(partitionId, schema) } override def createWriter( @@ -175,7 +166,7 @@ case class MemoryWriterFactory(outputMode: OutputMode, schema: StructType) } } -class MemoryDataWriter(partition: Int, outputMode: OutputMode, schema: StructType) +class MemoryDataWriter(partition: Int, schema: StructType) extends DataWriter[InternalRow] with Logging { private val data = mutable.Buffer[Row]() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala index e804377540517..a90acf85c0161 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala @@ -29,7 +29,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("data writer") { val partition = 1234 val writer = new MemoryDataWriter( - partition, OutputMode.Append(), new StructType().add("i", "int")) + partition, new StructType().add("i", "int")) writer.write(InternalRow(1)) writer.write(InternalRow(2)) writer.write(InternalRow(44)) @@ -44,7 +44,7 @@ class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { test("streaming writer") { val sink = new MemorySinkV2 val write = new MemoryStreamingWrite( - sink, OutputMode.Append(), new StructType().add("i", "int")) + sink, new StructType().add("i", "int"), needTruncate = false) write.commit(0, Array( MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), From 49dd06704538c76dbd253237a53a11f88623c7f0 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Fri, 8 Feb 2019 10:22:51 -0800 Subject: [PATCH 12/70] [SPARK-26389][SS] Add force delete temp checkpoint configuration ## What changes were proposed in this pull request? Not all users wants to keep temporary checkpoint directories. Additionally hard to restore from it. In this PR I've added a force delete flag which is default `false`. Additionally not clear for users when temporary checkpoint directory deleted so added log messages to explain this a bit more. ## How was this patch tested? Existing + additional unit tests. Closes #23732 from gaborgsomogyi/SPARK-26389. 
Authored-by: Gabor Somogyi Signed-off-by: Dongjoon Hyun --- .../apache/spark/sql/internal/SQLConf.scala | 6 ++++ .../execution/streaming/StreamExecution.scala | 11 +++++-- .../sql/streaming/StreamingQueryManager.scala | 8 +++-- .../test/DataStreamReaderWriterSuite.scala | 32 ++++++++++++++++++- 4 files changed, 51 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 714dc6cda578d..cdee09fead3c3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -918,6 +918,12 @@ object SQLConf { .stringConf .createOptional + val FORCE_DELETE_TEMP_CHECKPOINT_LOCATION = + buildConf("spark.sql.streaming.forceDeleteTempCheckpointLocation") + .doc("When true, enable temporary checkpoint locations force delete.") + .booleanConf + .createWithDefault(false) + val MIN_BATCHES_TO_RETAIN = buildConf("spark.sql.streaming.minBatchesToRetain") .internal() .doc("The minimum number of batches that must be retained and made recoverable.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index ea522ec90b0a1..3ad417c53da37 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -60,7 +60,8 @@ case object RECONFIGURING extends State * and the results are committed transactionally to the given [[Sink]]. * * @param deleteCheckpointOnStop whether to delete the checkpoint if the query is stopped without - * errors + * errors. Checkpoint deletion can be forced with the appropriate + * Spark configuration. 
*/ abstract class StreamExecution( override val sparkSession: SparkSession, @@ -97,6 +98,7 @@ abstract class StreamExecution( fs.mkdirs(checkpointPath) checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString } + logInfo(s"Checkpoint root $checkpointRoot resolved to $resolvedCheckpointRoot.") def logicalPlan: LogicalPlan @@ -340,10 +342,13 @@ abstract class StreamExecution( postEvent( new QueryTerminatedEvent(id, runId, exception.map(_.cause).map(Utils.exceptionString))) - // Delete the temp checkpoint only when the query didn't fail - if (deleteCheckpointOnStop && exception.isEmpty) { + // Delete the temp checkpoint when either force delete enabled or the query didn't fail + if (deleteCheckpointOnStop && + (sparkSession.sessionState.conf + .getConf(SQLConf.FORCE_DELETE_TEMP_CHECKPOINT_LOCATION) || exception.isEmpty)) { val checkpointPath = new Path(resolvedCheckpointRoot) try { + logInfo(s"Deleting checkpoint $checkpointPath.") val fs = checkpointPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) fs.delete(checkpointPath, true) } catch { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index e6773c5cc3bd4..59fae370d7ff4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -221,9 +221,13 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } }.getOrElse { if (useTempCheckpointLocation) { - // Delete the temp checkpoint when a query is being stopped without errors. deleteCheckpointOnStop = true - Utils.createTempDir(namePrefix = s"temporary").getCanonicalPath + val tempDir = Utils.createTempDir(namePrefix = s"temporary").getCanonicalPath + logWarning("Temporary checkpoint location created which is deleted normally when" + + s" the query didn't fail: $tempDir. If it's required to delete it under any" + + s" circumstances, please set ${SQLConf.FORCE_DELETE_TEMP_CHECKPOINT_LOCATION.key} to" + + s" true. 
Important to know deleting temp checkpoint folder is best effort.") + tempDir } else { throw new AnalysisException( "checkpointLocation must be specified either " + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 74ea0bfacba54..c3c7dcbaaece7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -614,6 +614,21 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } } + test("configured checkpoint dir should not be deleted if a query is stopped without errors and" + + " force temp checkpoint deletion enabled") { + import testImplicits._ + withTempDir { checkpointPath => + withSQLConf(SQLConf.CHECKPOINT_LOCATION.key -> checkpointPath.getAbsolutePath, + SQLConf.FORCE_DELETE_TEMP_CHECKPOINT_LOCATION.key -> "true") { + val ds = MemoryStream[Int].toDS + val query = ds.writeStream.format("console").start() + assert(checkpointPath.exists()) + query.stop() + assert(checkpointPath.exists()) + } + } + } + test("temp checkpoint dir should be deleted if a query is stopped without errors") { import testImplicits._ val query = MemoryStream[Int].toDS.writeStream.format("console").start() @@ -627,6 +642,17 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } testQuietly("temp checkpoint dir should not be deleted if a query is stopped with an error") { + testTempCheckpointWithFailedQuery(false) + } + + testQuietly("temp checkpoint should be deleted if a query is stopped with an error and force" + + " temp checkpoint deletion enabled") { + withSQLConf(SQLConf.FORCE_DELETE_TEMP_CHECKPOINT_LOCATION.key -> "true") { + testTempCheckpointWithFailedQuery(true) + } + } + + private def testTempCheckpointWithFailedQuery(checkpointMustBeDeleted: Boolean): Unit = { import testImplicits._ val input = MemoryStream[Int] val query = input.toDS.map(_ / 0).writeStream.format("console").start() @@ -638,7 +664,11 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { intercept[StreamingQueryException] { query.awaitTermination() } - assert(fs.exists(checkpointDir)) + if (!checkpointMustBeDeleted) { + assert(fs.exists(checkpointDir)) + } else { + assert(!fs.exists(checkpointDir)) + } } test("SPARK-20431: Specify a schema by using a DDL-formatted string") { From 1f5d3d41ee2c3ef3ed647f427e6976032860771a Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Wed, 20 Feb 2019 15:44:20 -0800 Subject: [PATCH 13/70] [SPARK-26824][SS] Fix the checkpoint location and _spark_metadata when it contains special chars MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What changes were proposed in this pull request? When a user specifies a checkpoint location or a file sink output using a path containing special chars that need to be escaped in a path, the streaming query will store checkpoint and file sink metadata in a wrong place. 
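To make the escaping behavior concrete, here is a small hedged sketch (illustration only, not code from this patch) showing how each round trip through `Path.toUri.toString` escapes the special characters in the user-supplied path one more time; the paths in the comments are approximate:

```scala
import org.apache.hadoop.fs.Path

// Each Path -> URI string -> Path round trip escapes space, '%' and '#' again.
val user   = new Path("/tmp/chk %#chk")
val once   = new Path(user.toUri.toString)    // /tmp/chk%20%25%23chk
val twice  = new Path(once.toUri.toString)    // /tmp/chk%2520%2525%2523chk
val thrice = new Path(twice.toUri.toString)   // /tmp/chk%252520%252525%252523chk
```

The `_spark_metadata` path was escaped once and the checkpoint path three times, which is how the directories shown below got their names.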
In this PR, I uploaded a checkpoint that was generated by the following codes using Spark 2.4.0 to show this issue: ``` implicit val s = spark.sqlContext val input = org.apache.spark.sql.execution.streaming.MemoryStream[Int] input.addData(1, 2, 3) val q = input.toDF.writeStream.format("parquet").option("checkpointLocation", ".../chk %#chk").start(".../output %#output") q.stop() ``` Here is the structure of the directory: ``` sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0 ├── chk%252520%252525%252523chk │   ├── commits │   │   └── 0 │   ├── metadata │   └── offsets │   └── 0 ├── output %#output │   └── part-00000-97f675a2-bb82-4201-8245-05f3dae4c372-c000.snappy.parquet └── output%20%25%23output └── _spark_metadata └── 0 ``` In this checkpoint, the user specified checkpoint location is `.../chk %#chk` but the real path to store the checkpoint is `.../chk%252520%252525%252523chk` (this is generated by escaping the original path three times). The user specified output path is `.../output %#output` but the path to store `_spark_metadata` is `.../output%20%25%23output/_spark_metadata` (this is generated by escaping the original path once). The data files are still in the correct path (such as `.../output %#output/part-00000-97f675a2-bb82-4201-8245-05f3dae4c372-c000.snappy.parquet`). This checkpoint will be used in unit tests in this PR. The fix is just simply removing improper `Path.toUri` calls to fix the issue. However, as the user may not read the release note and is not aware of this checkpoint location change, if they upgrade Spark without moving checkpoint to the new location, their query will just start from the scratch. In order to not surprise the users, this PR also adds a check to **detect the impacted paths and throws an error** to include the migration guide. This check can be turned off by an internal sql conf `spark.sql.streaming.checkpoint.escapedPathCheck.enabled`. Here are examples of errors that will be reported: - Streaming checkpoint error: ``` Error: we detected a possible problem with the location of your checkpoint and you likely need to move it before restarting this query. Earlier version of Spark incorrectly escaped paths when writing out checkpoints for structured streaming. While this was corrected in Spark 3.0, it appears that your query was started using an earlier version that incorrectly handled the checkpoint path. Correct Checkpoint Directory: /.../chk %#chk Incorrect Checkpoint Directory: /.../chk%252520%252525%252523chk Please move the data from the incorrect directory to the correct one, delete the incorrect directory, and then restart this query. If you believe you are receiving this message in error, you can disable it with the SQL conf spark.sql.streaming.checkpoint.escapedPathCheck.enabled. ``` - File sink error (`_spark_metadata`): ``` Error: we detected a possible problem with the location of your "_spark_metadata" directory and you likely need to move it before restarting this query. Earlier version of Spark incorrectly escaped paths when writing out the "_spark_metadata" directory for structured streaming. While this was corrected in Spark 3.0, it appears that your query was started using an earlier version that incorrectly handled the "_spark_metadata" path. 
Correct "_spark_metadata" Directory: /.../output %#output/_spark_metadata Incorrect "_spark_metadata" Directory: /.../output%20%25%23output/_spark_metadata Please move the data from the incorrect directory to the correct one, delete the incorrect directory, and then restart this query. If you believe you are receiving this message in error, you can disable it with the SQL conf spark.sql.streaming.checkpoint.escapedPathCheck.enabled. ``` ## How was this patch tested? The new unit tests. Closes #23733 from zsxwing/path-fix. Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../apache/spark/sql/internal/SQLConf.scala | 8 + .../execution/datasources/DataSource.scala | 3 +- .../execution/streaming/FileStreamSink.scala | 66 +++++-- .../streaming/FileStreamSource.scala | 4 +- .../streaming/MetadataLogFileIndex.scala | 10 +- .../execution/streaming/StreamExecution.scala | 48 ++++- .../sql/streaming/StreamingQueryManager.scala | 4 +- .../chk%252520%252525@%252523chk/commits/0 | 2 + .../chk%252520%252525@%252523chk/metadata | 1 + .../chk%252520%252525@%252523chk/offsets/0 | 3 + ...4201-8245-05f3dae4c372-c000.snappy.parquet | Bin 0 -> 404 bytes .../output%20%25@%23output/_spark_metadata/0 | 2 + .../sql/streaming/FileStreamSinkSuite.scala | 24 +++ .../sql/streaming/StreamingQuerySuite.scala | 186 ++++++++++++++++++ .../test/DataStreamReaderWriterSuite.scala | 8 +- 15 files changed, 341 insertions(+), 28 deletions(-) create mode 100644 sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/commits/0 create mode 100644 sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/metadata create mode 100644 sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/offsets/0 create mode 100644 sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output %@#output/part-00000-97f675a2-bb82-4201-8245-05f3dae4c372-c000.snappy.parquet create mode 100644 sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output%20%25@%23output/_spark_metadata/0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index cdee09fead3c3..56fbeb6b4f798 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1123,6 +1123,14 @@ object SQLConf { .internal() .stringConf + val STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED = + buildConf("spark.sql.streaming.checkpoint.escapedPathCheck.enabled") + .doc("Whether to detect a streaming query may pick up an incorrect checkpoint path due " + + "to SPARK-26824.") + .internal() + .booleanConf + .createWithDefault(true) + val PARALLEL_FILE_LISTING_IN_STATS_COMPUTATION = buildConf("spark.sql.statistics.parallelFileListingInStatsComputation.enabled") .internal() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala index b0548bc21156e..622ad3b559ebd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSource.scala @@ -348,7 +348,8 @@ case class DataSource( case (format: FileFormat, _) if FileStreamSink.hasMetadata( caseInsensitiveOptions.get("path").toSeq ++ paths, - 
sparkSession.sessionState.newHadoopConf()) => + sparkSession.sessionState.newHadoopConf(), + sparkSession.sessionState.conf) => val basePath = new Path((caseInsensitiveOptions.get("path").toSeq ++ paths).head) val fileCatalog = new MetadataLogFileIndex(sparkSession, basePath, userSpecifiedSchema) val dataSchema = userSpecifiedSchema.orElse { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala index b3d12f67b5d63..b679f163fc561 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSink.scala @@ -20,13 +20,15 @@ package org.apache.spark.sql.execution.streaming import scala.util.control.NonFatal import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.Path +import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.internal.io.FileCommitProtocol import org.apache.spark.sql.{DataFrame, SparkSession} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, FileFormat, FileFormatWriter} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.util.SerializableConfiguration object FileStreamSink extends Logging { @@ -37,23 +39,54 @@ object FileStreamSink extends Logging { * Returns true if there is a single path that has a metadata log indicating which files should * be read. */ - def hasMetadata(path: Seq[String], hadoopConf: Configuration): Boolean = { + def hasMetadata(path: Seq[String], hadoopConf: Configuration, sqlConf: SQLConf): Boolean = { path match { case Seq(singlePath) => + val hdfsPath = new Path(singlePath) + val fs = hdfsPath.getFileSystem(hadoopConf) + if (fs.isDirectory(hdfsPath)) { + val metadataPath = new Path(hdfsPath, metadataDir) + checkEscapedMetadataPath(fs, metadataPath, sqlConf) + fs.exists(metadataPath) + } else { + false + } + case _ => false + } + } + + def checkEscapedMetadataPath(fs: FileSystem, metadataPath: Path, sqlConf: SQLConf): Unit = { + if (sqlConf.getConf(SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED) + && StreamExecution.containsSpecialCharsInPath(metadataPath)) { + val legacyMetadataPath = new Path(metadataPath.toUri.toString) + val legacyMetadataPathExists = try { - val hdfsPath = new Path(singlePath) - val fs = hdfsPath.getFileSystem(hadoopConf) - if (fs.isDirectory(hdfsPath)) { - fs.exists(new Path(hdfsPath, metadataDir)) - } else { - false - } + fs.exists(legacyMetadataPath) } catch { case NonFatal(e) => - logWarning(s"Error while looking for metadata directory.") + // We may not have access to this directory. Don't fail the query if that happens. + logWarning(e.getMessage, e) false } - case _ => false + if (legacyMetadataPathExists) { + throw new SparkException( + s"""Error: we detected a possible problem with the location of your "_spark_metadata" + |directory and you likely need to move it before restarting this query. + | + |Earlier version of Spark incorrectly escaped paths when writing out the + |"_spark_metadata" directory for structured streaming. While this was corrected in + |Spark 3.0, it appears that your query was started using an earlier version that + |incorrectly handled the "_spark_metadata" path. 
+ | + |Correct "_spark_metadata" Directory: $metadataPath + |Incorrect "_spark_metadata" Directory: $legacyMetadataPath + | + |Please move the data from the incorrect directory to the correct one, delete the + |incorrect directory, and then restart this query. If you believe you are receiving + |this message in error, you can disable it with the SQL conf + |${SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED.key}.""" + .stripMargin) + } } } @@ -92,11 +125,16 @@ class FileStreamSink( partitionColumnNames: Seq[String], options: Map[String, String]) extends Sink with Logging { + private val hadoopConf = sparkSession.sessionState.newHadoopConf() private val basePath = new Path(path) - private val logPath = new Path(basePath, FileStreamSink.metadataDir) + private val logPath = { + val metadataDir = new Path(basePath, FileStreamSink.metadataDir) + val fs = metadataDir.getFileSystem(hadoopConf) + FileStreamSink.checkEscapedMetadataPath(fs, metadataDir, sparkSession.sessionState.conf) + metadataDir + } private val fileLog = - new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toUri.toString) - private val hadoopConf = sparkSession.sessionState.newHadoopConf() + new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, logPath.toString) private def basicWriteJobStatsTracker: BasicWriteJobStatsTracker = { val serializableHadoopConf = new SerializableConfiguration(hadoopConf) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala index 103fa7ce9066d..43b70ae0a51b1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/FileStreamSource.scala @@ -208,7 +208,7 @@ class FileStreamSource( var allFiles: Seq[FileStatus] = null sourceHasMetadata match { case None => - if (FileStreamSink.hasMetadata(Seq(path), hadoopConf)) { + if (FileStreamSink.hasMetadata(Seq(path), hadoopConf, sparkSession.sessionState.conf)) { sourceHasMetadata = Some(true) allFiles = allFilesUsingMetadataLogFileIndex() } else { @@ -220,7 +220,7 @@ class FileStreamSource( // double check whether source has metadata, preventing the extreme corner case that // metadata log and data files are only generated after the previous // `FileStreamSink.hasMetadata` check - if (FileStreamSink.hasMetadata(Seq(path), hadoopConf)) { + if (FileStreamSink.hasMetadata(Seq(path), hadoopConf, sparkSession.sessionState.conf)) { sourceHasMetadata = Some(true) allFiles = allFilesUsingMetadataLogFileIndex() } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala index 5cacdd070b735..80eed7b277216 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MetadataLogFileIndex.scala @@ -39,10 +39,16 @@ class MetadataLogFileIndex( userSpecifiedSchema: Option[StructType]) extends PartitioningAwareFileIndex(sparkSession, Map.empty, userSpecifiedSchema) { - private val metadataDirectory = new Path(path, FileStreamSink.metadataDir) + private val metadataDirectory = { + val metadataDir = new Path(path, FileStreamSink.metadataDir) + val fs = metadataDir.getFileSystem(sparkSession.sessionState.newHadoopConf()) + 
FileStreamSink.checkEscapedMetadataPath(fs, metadataDir, sparkSession.sessionState.conf) + metadataDir + } + logInfo(s"Reading streaming file log from $metadataDirectory") private val metadataLog = - new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, metadataDirectory.toUri.toString) + new FileStreamSinkLog(FileStreamSinkLog.VERSION, sparkSession, metadataDirectory.toString) private val allFilesFromLog = metadataLog.allFiles().map(_.toFileStatus).filterNot(_.isDirectory) private var cachedPartitionSpec: PartitionSpec = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 3ad417c53da37..bba640eea7e5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -95,8 +95,45 @@ abstract class StreamExecution( val resolvedCheckpointRoot = { val checkpointPath = new Path(checkpointRoot) val fs = checkpointPath.getFileSystem(sparkSession.sessionState.newHadoopConf()) - fs.mkdirs(checkpointPath) - checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory).toUri.toString + if (sparkSession.conf.get(SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED) + && StreamExecution.containsSpecialCharsInPath(checkpointPath)) { + // In Spark 2.4 and earlier, the checkpoint path is escaped 3 times (3 `Path.toUri.toString` + // calls). If this legacy checkpoint path exists, we will throw an error to tell the user how + // to migrate. + val legacyCheckpointDir = + new Path(new Path(checkpointPath.toUri.toString).toUri.toString).toUri.toString + val legacyCheckpointDirExists = + try { + fs.exists(new Path(legacyCheckpointDir)) + } catch { + case NonFatal(e) => + // We may not have access to this directory. Don't fail the query if that happens. + logWarning(e.getMessage, e) + false + } + if (legacyCheckpointDirExists) { + throw new SparkException( + s"""Error: we detected a possible problem with the location of your checkpoint and you + |likely need to move it before restarting this query. + | + |Earlier version of Spark incorrectly escaped paths when writing out checkpoints for + |structured streaming. While this was corrected in Spark 3.0, it appears that your + |query was started using an earlier version that incorrectly handled the checkpoint + |path. + | + |Correct Checkpoint Directory: $checkpointPath + |Incorrect Checkpoint Directory: $legacyCheckpointDir + | + |Please move the data from the incorrect directory to the correct one, delete the + |incorrect directory, and then restart this query. If you believe you are receiving + |this message in error, you can disable it with the SQL conf + |${SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED.key}.""" + .stripMargin) + } + } + val checkpointDir = checkpointPath.makeQualified(fs.getUri, fs.getWorkingDirectory) + fs.mkdirs(checkpointDir) + checkpointDir.toString } logInfo(s"Checkpoint root $checkpointRoot resolved to $resolvedCheckpointRoot.") @@ -232,7 +269,7 @@ abstract class StreamExecution( /** Returns the path of a file with `name` in the checkpoint directory. */ protected def checkpointFile(name: String): String = - new Path(new Path(resolvedCheckpointRoot), name).toUri.toString + new Path(new Path(resolvedCheckpointRoot), name).toString /** * Starts the execution. 
This returns only after the thread has started and [[QueryStartedEvent]] @@ -607,6 +644,11 @@ object StreamExecution { case _ => false } + + /** Whether the path contains special chars that will be escaped when converting to a `URI`. */ + def containsSpecialCharsInPath(path: Path): Boolean = { + path.toUri.getPath != new Path(path.toUri.toString).toUri.getPath + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 59fae370d7ff4..9d86cac9cec5b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -214,10 +214,10 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo triggerClock: Clock): StreamingQueryWrapper = { var deleteCheckpointOnStop = false val checkpointLocation = userSpecifiedCheckpointLocation.map { userSpecified => - new Path(userSpecified).toUri.toString + new Path(userSpecified).toString }.orElse { df.sparkSession.sessionState.conf.checkpointLocation.map { location => - new Path(location, userSpecifiedName.getOrElse(UUID.randomUUID().toString)).toUri.toString + new Path(location, userSpecifiedName.getOrElse(UUID.randomUUID().toString)).toString } }.getOrElse { if (useTempCheckpointLocation) { diff --git a/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/commits/0 b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/commits/0 new file mode 100644 index 0000000000000..9c1e3021c3ead --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/commits/0 @@ -0,0 +1,2 @@ +v1 +{"nextBatchWatermarkMs":0} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/metadata b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/metadata new file mode 100644 index 0000000000000..3071b0dfc550b --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/metadata @@ -0,0 +1 @@ +{"id":"09be7fb3-49d8-48a6-840d-e9c2ad92a898"} \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/offsets/0 b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/offsets/0 new file mode 100644 index 0000000000000..a0a567631fd14 --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/chk%252520%252525@%252523chk/offsets/0 @@ -0,0 +1,3 @@ +v1 +{"batchWatermarkMs":0,"batchTimestampMs":1549649384149,"conf":{"spark.sql.streaming.stateStore.providerClass":"org.apache.spark.sql.execution.streaming.state.HDFSBackedStateStoreProvider","spark.sql.streaming.flatMapGroupsWithState.stateFormatVersion":"2","spark.sql.streaming.multipleWatermarkPolicy":"min","spark.sql.streaming.aggregation.stateFormatVersion":"2","spark.sql.shuffle.partitions":"200"}} +0 \ No newline at end of file diff --git a/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output %@#output/part-00000-97f675a2-bb82-4201-8245-05f3dae4c372-c000.snappy.parquet b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output 
%@#output/part-00000-97f675a2-bb82-4201-8245-05f3dae4c372-c000.snappy.parquet new file mode 100644 index 0000000000000000000000000000000000000000..1b2919b25c381a343c47e0704e84c845700e3a76 GIT binary patch literal 404 zcmaKp&uhXk6vwm9l^(pNB#=W1C<Mc zwhCC{j1|q!2;lEJ-3q(;5|wUMp;|?j2i|^fHJ|gQwO!uwkgrN@iiXeM4!l;?sdSza zDi>#2w|cED2z6(S$_#i`$}~FUzkT`qa6B%Lh&*4Yx0Ma{+BW5;4k8grA%juYm8J+} s5~=BQl1h26lWZ2tSjs3?(JW2Ud@_r(;x5ha5;dlb!HzHRvJ3c+AJ?T>vj6}9 literal 0 HcmV?d00001 diff --git a/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output%20%25@%23output/_spark_metadata/0 b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output%20%25@%23output/_spark_metadata/0 new file mode 100644 index 0000000000000..79768f89d6eca --- /dev/null +++ b/sql/core/src/test/resources/structured-streaming/escaped-path-2.4.0/output%20%25@%23output/_spark_metadata/0 @@ -0,0 +1,2 @@ +v1 +{"path":"file://TEMPDIR/output%20%25@%23output/part-00000-97f675a2-bb82-4201-8245-05f3dae4c372-c000.snappy.parquet","size":404,"isDir":false,"modificationTime":1549649385000,"blockReplication":1,"blockSize":33554432,"action":"add"} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala index ed53def556cb8..619d118e20873 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSinkSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.streaming +import java.io.File import java.util.Locale import org.apache.hadoop.fs.Path @@ -454,4 +455,27 @@ class FileStreamSinkSuite extends StreamTest { } } } + + test("special characters in output path") { + withTempDir { tempDir => + val checkpointDir = new File(tempDir, "chk") + val outputDir = new File(tempDir, "output @#output") + val inputData = MemoryStream[Int] + inputData.addData(1, 2, 3) + val q = inputData.toDF() + .writeStream + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .format("parquet") + .start(outputDir.getCanonicalPath) + try { + q.processAllAvailable() + } finally { + q.stop() + } + // The "_spark_metadata" directory should be in "outputDir" + assert(outputDir.listFiles.map(_.getName).contains(FileStreamSink.metadataDir)) + val outputDf = spark.read.parquet(outputDir.getCanonicalPath).as[Int] + checkDatasetUnorderly(outputDf, 1, 2, 3) + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index dc22e31678fa3..729173cb7104f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -17,11 +17,14 @@ package org.apache.spark.sql.streaming +import java.io.File import java.util.concurrent.CountDownLatch import scala.collection.mutable +import org.apache.commons.io.{FileUtils, IOUtils} import org.apache.commons.lang3.RandomStringUtils +import org.apache.hadoop.fs.Path import org.scalactic.TolerantNumerics import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.PatienceConfiguration.Timeout @@ -915,6 +918,189 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi ) } + test("special characters in checkpoint path") { + withTempDir { tempDir => + val checkpointDir = new File(tempDir, "chk @#chk") + val inputData = 
MemoryStream[Int] + inputData.addData(1) + val q = inputData.toDF() + .writeStream + .format("noop") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .start() + try { + q.processAllAvailable() + assert(checkpointDir.listFiles().toList.nonEmpty) + } finally { + q.stop() + } + } + } + + /** + * Copy the checkpoint generated by Spark 2.4.0 from test resource to `dir` to set up a legacy + * streaming checkpoint. + */ + private def setUp2dot4dot0Checkpoint(dir: File): Unit = { + val input = getClass.getResource("/structured-streaming/escaped-path-2.4.0") + assert(input != null, "cannot find test resource '/structured-streaming/escaped-path-2.4.0'") + val inputDir = new File(input.toURI) + + // Copy test files to tempDir so that we won't modify the original data. + FileUtils.copyDirectory(inputDir, dir) + + // Spark 2.4 and earlier escaped the _spark_metadata path once + val legacySparkMetadataDir = new File( + dir, + new Path("output %@#output/_spark_metadata").toUri.toString) + + // Migrate from legacy _spark_metadata directory to the new _spark_metadata directory. + // Ideally we should copy "_spark_metadata" directly like what the user is supposed to do to + // migrate to new version. However, in our test, "tempDir" will be different in each run and + // we need to fix the absolute path in the metadata to match "tempDir". + val sparkMetadata = FileUtils.readFileToString(new File(legacySparkMetadataDir, "0"), "UTF-8") + FileUtils.write( + new File(legacySparkMetadataDir, "0"), + sparkMetadata.replaceAll("TEMPDIR", dir.getCanonicalPath), + "UTF-8") + } + + test("detect escaped path and report the migration guide") { + // Assert that the error message contains the migration conf, path and the legacy path. + def assertMigrationError(errorMessage: String, path: File, legacyPath: File): Unit = { + Seq(SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED.key, + path.getCanonicalPath, + legacyPath.getCanonicalPath).foreach { msg => + assert(errorMessage.contains(msg)) + } + } + + withTempDir { tempDir => + setUp2dot4dot0Checkpoint(tempDir) + + // Here are the paths we will use to create the query + val outputDir = new File(tempDir, "output %@#output") + val checkpointDir = new File(tempDir, "chk %@#chk") + val sparkMetadataDir = new File(tempDir, "output %@#output/_spark_metadata") + + // The escaped paths used by Spark 2.4 and earlier. 
+ // Spark 2.4 and earlier escaped the checkpoint path three times + val legacyCheckpointDir = new File( + tempDir, + new Path(new Path(new Path("chk %@#chk").toUri.toString).toUri.toString).toUri.toString) + // Spark 2.4 and earlier escaped the _spark_metadata path once + val legacySparkMetadataDir = new File( + tempDir, + new Path("output %@#output/_spark_metadata").toUri.toString) + + // Reading a file sink output in a batch query should detect the legacy _spark_metadata + // directory and throw an error + val e = intercept[SparkException] { + spark.read.load(outputDir.getCanonicalPath).as[Int] + } + assertMigrationError(e.getMessage, sparkMetadataDir, legacySparkMetadataDir) + + // Restarting the streaming query should detect the legacy _spark_metadata directory and throw + // an error + val inputData = MemoryStream[Int] + val e2 = intercept[SparkException] { + inputData.toDF() + .writeStream + .format("parquet") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .start(outputDir.getCanonicalPath) + } + assertMigrationError(e2.getMessage, sparkMetadataDir, legacySparkMetadataDir) + + // Move "_spark_metadata" to fix the file sink and test the checkpoint path. + FileUtils.moveDirectory(legacySparkMetadataDir, sparkMetadataDir) + + // Restarting the streaming query should detect the legacy checkpoint path and throw an error + val e3 = intercept[SparkException] { + inputData.toDF() + .writeStream + .format("parquet") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .start(outputDir.getCanonicalPath) + } + assertMigrationError(e3.getMessage, checkpointDir, legacyCheckpointDir) + + // Fix the checkpoint path and verify that the user can migrate the issue by moving files. + FileUtils.moveDirectory(legacyCheckpointDir, checkpointDir) + + val q = inputData.toDF() + .writeStream + .format("parquet") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .start(outputDir.getCanonicalPath) + try { + q.processAllAvailable() + // Check the query id to make sure it did use checkpoint + assert(q.id.toString == "09be7fb3-49d8-48a6-840d-e9c2ad92a898") + + // Verify that the batch query can read "_spark_metadata" correctly after migration. 
+ val df = spark.read.load(outputDir.getCanonicalPath) + assert(df.queryExecution.executedPlan.toString contains "MetadataLogFileIndex") + checkDatasetUnorderly(df.as[Int], 1, 2, 3) + } finally { + q.stop() + } + } + } + + test("ignore the escaped path check when the flag is off") { + withTempDir { tempDir => + setUp2dot4dot0Checkpoint(tempDir) + val outputDir = new File(tempDir, "output %@#output") + val checkpointDir = new File(tempDir, "chk %@#chk") + + withSQLConf(SQLConf.STREAMING_CHECKPOINT_ESCAPED_PATH_CHECK_ENABLED.key -> "false") { + // Verify that the batch query ignores the legacy "_spark_metadata" + val df = spark.read.load(outputDir.getCanonicalPath) + assert(!(df.queryExecution.executedPlan.toString contains "MetadataLogFileIndex")) + checkDatasetUnorderly(df.as[Int], 1, 2, 3) + + val inputData = MemoryStream[Int] + val q = inputData.toDF() + .writeStream + .format("parquet") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .start(outputDir.getCanonicalPath) + try { + q.processAllAvailable() + // Check the query id to make sure it ignores the legacy checkpoint + assert(q.id.toString != "09be7fb3-49d8-48a6-840d-e9c2ad92a898") + } finally { + q.stop() + } + } + } + } + + test("containsSpecialCharsInPath") { + Seq("foo/b ar", + "/foo/b ar", + "file:/foo/b ar", + "file://foo/b ar", + "file:///foo/b ar", + "file://foo:bar@bar/foo/b ar").foreach { p => + assert(StreamExecution.containsSpecialCharsInPath(new Path(p)), s"failed to check $p") + } + Seq("foo/bar", + "/foo/bar", + "file:/foo/bar", + "file://foo/bar", + "file:///foo/bar", + "file://foo:bar@bar/foo/bar", + // Special chars not in a path should not be considered as such urls won't hit the escaped + // path issue. + "file://foo:b ar@bar/foo/bar", + "file://foo:bar@b ar/foo/bar", + "file://f oo:bar@bar/foo/bar").foreach { p => + assert(!StreamExecution.containsSpecialCharsInPath(new Path(p)), s"failed to check $p") + } + } + /** Create a streaming DF that only execute one batch in which it returns the given static DF */ private def createSingleTriggerStreamingDF(triggerDF: DataFrame): DataFrame = { require(!triggerDF.isStreaming) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index c3c7dcbaaece7..99dc0769a3d69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -359,7 +359,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { test("source metadataPath") { LastOptions.clear() - val checkpointLocationURI = new Path(newMetadataDir).toUri + val checkpointLocation = new Path(newMetadataDir) val df1 = spark.readStream .format("org.apache.spark.sql.streaming.test") @@ -371,7 +371,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { val q = df1.union(df2).writeStream .format("org.apache.spark.sql.streaming.test") - .option("checkpointLocation", checkpointLocationURI.toString) + .option("checkpointLocation", checkpointLocation.toString) .trigger(ProcessingTime(10.seconds)) .start() q.processAllAvailable() @@ -379,14 +379,14 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { verify(LastOptions.mockStreamSourceProvider).createSource( any(), - meq(s"${makeQualifiedPath(checkpointLocationURI.toString)}/sources/0"), + meq(s"${new 
Path(makeQualifiedPath(checkpointLocation.toString)).toString}/sources/0"), meq(None), meq("org.apache.spark.sql.streaming.test"), meq(Map.empty)) verify(LastOptions.mockStreamSourceProvider).createSource( any(), - meq(s"${makeQualifiedPath(checkpointLocationURI.toString)}/sources/1"), + meq(s"${new Path(makeQualifiedPath(checkpointLocation.toString)).toString}/sources/1"), meq(None), meq("org.apache.spark.sql.streaming.test"), meq(Map.empty)) From 982df040fae7c655c44c9e4e6e09179a81f39021 Mon Sep 17 00:00:00 2001 From: Shixiong Zhu Date: Sat, 9 Mar 2019 14:26:58 -0800 Subject: [PATCH 14/70] [SPARK-27111][SS] Fix a race that a continuous query may fail with InterruptedException ## What changes were proposed in this pull request? Before a Kafka consumer gets assigned with partitions, its offset will contain 0 partitions. However, runContinuous will still run and launch a Spark job having 0 partitions. In this case, there is a race that epoch may interrupt the query execution thread after `lastExecution.toRdd`, and either `epochEndpoint.askSync[Unit](StopContinuousExecutionWrites)` or the next `runContinuous` will get interrupted unintentionally. To handle this case, this PR has the following changes: - Clean up the resources in `queryExecutionThread.runUninterruptibly`. This may increase the waiting time of `stop` but should be minor because the operations here are very fast (just sending an RPC message in the same process and stopping a very simple thread). - Clear the interrupted status at the end so that it won't impact the `runContinuous` call. We may clear the interrupted status set by `stop`, but it doesn't affect the query termination because `runActivatedStream` will check `state` and exit accordingly. I also updated the clean up codes to make sure exceptions thrown from `epochEndpoint.askSync[Unit](StopContinuousExecutionWrites)` won't stop the clean up. ## How was this patch tested? Jenkins Closes #24034 from zsxwing/SPARK-27111. Authored-by: Shixiong Zhu Signed-off-by: Shixiong Zhu --- .../continuous/ContinuousExecution.scala | 30 ++++++++++++++----- 1 file changed, 23 insertions(+), 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index a1ac55ca4ce25..c2aab88b1b8da 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -259,13 +259,29 @@ class ContinuousExecution( logInfo(s"Query $id ignoring exception from reconfiguring: $t") // interrupted by reconfiguration - swallow exception so we can restart the query } finally { - epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) - SparkEnv.get.rpcEnv.stop(epochEndpoint) - - epochUpdateThread.interrupt() - epochUpdateThread.join() - - sparkSession.sparkContext.cancelJobGroup(runId.toString) + // The above execution may finish before getting interrupted, for example, a Spark job having + // 0 partitions will complete immediately. Then the interrupted status will sneak here. + // + // To handle this case, we do the two things here: + // + // 1. Clean up the resources in `queryExecutionThread.runUninterruptibly`. 
This may increase + // the waiting time of `stop` but should be minor because the operations here are very fast + // (just sending an RPC message in the same process and stopping a very simple thread). + // 2. Clear the interrupted status at the end so that it won't impact the `runContinuous` + // call. We may clear the interrupted status set by `stop`, but it doesn't affect the query + // termination because `runActivatedStream` will check `state` and exit accordingly. + queryExecutionThread.runUninterruptibly { + try { + epochEndpoint.askSync[Unit](StopContinuousExecutionWrites) + } finally { + SparkEnv.get.rpcEnv.stop(epochEndpoint) + epochUpdateThread.interrupt() + epochUpdateThread.join() + // The following line must be the last line because it may fail if SparkContext is stopped + sparkSession.sparkContext.cancelJobGroup(runId.toString) + } + } + Thread.interrupted() } } From 38556e7676d00abbe1b210d0c9cbf6be3bf19e69 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Wed, 27 Feb 2019 09:52:43 -0800 Subject: [PATCH 15/70] [SPARK-24063][SS] Add maximum epoch queue threshold for ContinuousExecution ## What changes were proposed in this pull request? Continuous processing is waiting on epochs which are not yet complete (for example one partition is not making progress) and stores pending items in queues. These queues are unbounded and can consume up all the memory easily. In this PR I've added `spark.sql.streaming.continuous.epochBacklogQueueSize` configuration possibility to make them bounded. If the related threshold reached then the query will stop with `IllegalStateException`. ## How was this patch tested? Existing + additional unit tests. Closes #23156 from gaborgsomogyi/SPARK-24063. Authored-by: Gabor Somogyi Signed-off-by: Marcelo Vanzin --- .../apache/spark/sql/internal/SQLConf.scala | 10 +++ .../continuous/ContinuousExecution.scala | 38 +++++++++ .../continuous/EpochCoordinator.scala | 20 +++++ .../continuous/ContinuousSuite.scala | 31 +++++++ .../continuous/EpochCoordinatorSuite.scala | 81 ++++++++++++++++++- 5 files changed, 177 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 56fbeb6b4f798..5900a72f3387e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1441,6 +1441,13 @@ object SQLConf { .booleanConf .createWithDefault(true) + val CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE = + buildConf("spark.sql.streaming.continuous.epochBacklogQueueSize") + .doc("The max number of entries to be stored in queue to wait for late epochs. 
" + + "If this parameter is exceeded by the size of the queue, stream will stop with an error.") + .intConf + .createWithDefault(10000) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") .internal() @@ -2073,6 +2080,9 @@ class SQLConf extends Serializable with Logging { def literalPickMinimumPrecision: Boolean = getConf(LITERAL_PICK_MINIMUM_PRECISION) + def continuousStreamingEpochBacklogQueueSize: Int = + getConf(CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE) + def continuousStreamingExecutorQueueSize: Int = getConf(CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE) def continuousStreamingExecutorPollIntervalMs: Long = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index c2aab88b1b8da..aef556d92cc83 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.streaming.continuous import java.util.UUID import java.util.concurrent.TimeUnit +import java.util.concurrent.atomic.AtomicReference import java.util.function.UnaryOperator import scala.collection.JavaConverters._ @@ -57,6 +58,9 @@ class ContinuousExecution( // For use only in test harnesses. private[sql] var currentEpochCoordinatorId: String = _ + // Throwable that caused the execution to fail + private val failure: AtomicReference[Throwable] = new AtomicReference[Throwable](null) + override val logicalPlan: LogicalPlan = { val v2ToRelationMap = MutableMap[StreamingRelationV2, StreamingDataSourceV2Relation]() var nextSourceId = 0 @@ -253,6 +257,11 @@ class ContinuousExecution( lastExecution.toRdd } } + + val f = failure.get() + if (f != null) { + throw f + } } catch { case t: Throwable if StreamExecution.isInterruptionException(t, sparkSession.sparkContext) && state.get() == RECONFIGURING => @@ -381,6 +390,35 @@ class ContinuousExecution( } } + /** + * Stores error and stops the query execution thread to terminate the query in new thread. + */ + def stopInNewThread(error: Throwable): Unit = { + if (failure.compareAndSet(null, error)) { + logError(s"Query $prettyIdString received exception $error") + stopInNewThread() + } + } + + /** + * Stops the query execution thread to terminate the query in new thread. + */ + private def stopInNewThread(): Unit = { + new Thread("stop-continuous-execution") { + setDaemon(true) + + override def run(): Unit = { + try { + ContinuousExecution.this.stop() + } catch { + case e: Throwable => + logError(e.getMessage, e) + throw e + } + } + }.start() + } + /** * Stops the query execution thread to terminate the query. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala index a99842220424d..decf524f7167c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/EpochCoordinator.scala @@ -123,6 +123,9 @@ private[continuous] class EpochCoordinator( override val rpcEnv: RpcEnv) extends ThreadSafeRpcEndpoint with Logging { + private val epochBacklogQueueSize = + session.sqlContext.conf.continuousStreamingEpochBacklogQueueSize + private var queryWritesStopped: Boolean = false private var numReaderPartitions: Int = _ @@ -212,6 +215,7 @@ private[continuous] class EpochCoordinator( if (!partitionCommits.isDefinedAt((epoch, partitionId))) { partitionCommits.put((epoch, partitionId), message) resolveCommitsAtEpoch(epoch) + checkProcessingQueueBoundaries() } case ReportPartitionOffset(partitionId, epoch, offset) => @@ -223,6 +227,22 @@ private[continuous] class EpochCoordinator( query.addOffset(epoch, stream, thisEpochOffsets.toSeq) resolveCommitsAtEpoch(epoch) } + checkProcessingQueueBoundaries() + } + + private def checkProcessingQueueBoundaries() = { + if (partitionOffsets.size > epochBacklogQueueSize) { + query.stopInNewThread(new IllegalStateException("Size of the partition offset queue has " + + "exceeded its maximum")) + } + if (partitionCommits.size > epochBacklogQueueSize) { + query.stopInNewThread(new IllegalStateException("Size of the partition commit queue has " + + "exceeded its maximum")) + } + if (epochsWaitingToBeCommitted.size > epochBacklogQueueSize) { + query.stopInNewThread(new IllegalStateException("Size of the epoch queue has " + + "exceeded its maximum")) + } } override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index 344a8aa55f0c5..d2e489a7d4ad2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream import org.apache.spark.sql.functions._ +import org.apache.spark.sql.internal.SQLConf.CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.test.TestSparkSession @@ -343,3 +344,33 @@ class ContinuousMetaSuite extends ContinuousSuiteBase { } } } + +class ContinuousEpochBacklogSuite extends ContinuousSuiteBase { + import testImplicits._ + + override protected def createSparkSession = new TestSparkSession( + new SparkContext( + "local[1]", + "continuous-stream-test-sql-context", + sparkConf.set("spark.sql.testkey", "true"))) + + // This test forces the backlog to overflow by not standing up enough executors for the query + // to make progress. 
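+  // The backlog threshold applies to each of the coordinator's queues (partition offsets,
+  // partition commits, and epochs waiting to be committed). Outside of tests it is set like any
+  // other SQL conf, for example
+  //   spark.conf.set("spark.sql.streaming.continuous.epochBacklogQueueSize", "10000")
+  // (10000 is also the default); once a queue exceeds the threshold, the query stops with an
+  // IllegalStateException.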
+ test("epoch backlog overflow") { + withSQLConf((CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE.key, "10")) { + val df = spark.readStream + .format("rate") + .option("numPartitions", "2") + .option("rowsPerSecond", "500") + .load() + .select('value) + + testStream(df, useV2Sink = true)( + StartStream(Trigger.Continuous(1)), + ExpectFailure[IllegalStateException] { e => + e.getMessage.contains("queue has exceeded its maximum") + } + ) + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala index f74285f4b0fb3..e3498db4194e8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/EpochCoordinatorSuite.scala @@ -17,9 +17,9 @@ package org.apache.spark.sql.streaming.continuous +import org.mockito.{ArgumentCaptor, InOrder} import org.mockito.ArgumentMatchers.{any, eq => eqTo} -import org.mockito.InOrder -import org.mockito.Mockito.{inOrder, never, verify} +import org.mockito.Mockito._ import org.scalatest.BeforeAndAfterEach import org.scalatest.mockito.MockitoSugar @@ -27,6 +27,7 @@ import org.apache.spark._ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.LocalSparkSession import org.apache.spark.sql.execution.streaming.continuous._ +import org.apache.spark.sql.internal.SQLConf.CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite @@ -43,6 +44,7 @@ class EpochCoordinatorSuite private var writeSupport: StreamingWrite = _ private var query: ContinuousExecution = _ private var orderVerifier: InOrder = _ + private val epochBacklogQueueSize = 10 override def beforeEach(): Unit = { val stream = mock[ContinuousStream] @@ -50,7 +52,11 @@ class EpochCoordinatorSuite query = mock[ContinuousExecution] orderVerifier = inOrder(writeSupport, query) - spark = new TestSparkSession() + spark = new TestSparkSession( + new SparkContext( + "local[2]", "test-sql-context", + new SparkConf().set("spark.sql.testkey", "true") + .set(CONTINUOUS_STREAMING_EPOCH_BACKLOG_QUEUE_SIZE, epochBacklogQueueSize))) epochCoordinator = EpochCoordinatorRef.create(writeSupport, stream, query, "test", 1, spark, SparkEnv.get) @@ -186,6 +192,66 @@ class EpochCoordinatorSuite verifyCommitsInOrderOf(List(1, 2, 3, 4, 5)) } + test("several epochs, max epoch backlog reached by partitionOffsets") { + setWriterPartitions(1) + setReaderPartitions(1) + + reportPartitionOffset(0, 1) + // Commit messages not arriving + for (i <- 2 to epochBacklogQueueSize + 1) { + reportPartitionOffset(0, i) + } + + makeSynchronousCall() + + for (i <- 1 to epochBacklogQueueSize + 1) { + verifyNoCommitFor(i) + } + verifyStoppedWithException("Size of the partition offset queue has exceeded its maximum") + } + + test("several epochs, max epoch backlog reached by partitionCommits") { + setWriterPartitions(1) + setReaderPartitions(1) + + commitPartitionEpoch(0, 1) + // Offset messages not arriving + for (i <- 2 to epochBacklogQueueSize + 1) { + commitPartitionEpoch(0, i) + } + + makeSynchronousCall() + + for (i <- 1 to epochBacklogQueueSize + 1) { + verifyNoCommitFor(i) + } + verifyStoppedWithException("Size of the partition commit queue has exceeded its 
maximum") + } + + test("several epochs, max epoch backlog reached by epochsWaitingToBeCommitted") { + setWriterPartitions(2) + setReaderPartitions(2) + + commitPartitionEpoch(0, 1) + reportPartitionOffset(0, 1) + + // For partition 2 epoch 1 messages never arriving + // +2 because the first epoch not yet arrived + for (i <- 2 to epochBacklogQueueSize + 2) { + commitPartitionEpoch(0, i) + reportPartitionOffset(0, i) + commitPartitionEpoch(1, i) + reportPartitionOffset(1, i) + } + + makeSynchronousCall() + + for (i <- 1 to epochBacklogQueueSize + 2) { + verifyNoCommitFor(i) + } + verifyStoppedWithException("Size of the epoch queue has exceeded its maximum") + } + private def setWriterPartitions(numPartitions: Int): Unit = { epochCoordinator.askSync[Unit](SetWriterPartitions(numPartitions)) } @@ -221,4 +287,13 @@ class EpochCoordinatorSuite private def verifyCommitsInOrderOf(epochs: Seq[Long]): Unit = { epochs.foreach(verifyCommit) } + + private def verifyStoppedWithException(msg: String): Unit = { + val exceptionCaptor = ArgumentCaptor.forClass(classOf[Throwable]); + verify(query, atLeastOnce()).stopInNewThread(exceptionCaptor.capture()) + + import scala.collection.JavaConverters._ + val throwable = exceptionCaptor.getAllValues.asScala.find(_.getMessage === msg) + assert(throwable != null, "Stream stopped with an exception but expected message is missing") + } } From 3fecdd9f8f5445521c73b58907b59f66cdbdfbd7 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 13 Mar 2019 19:47:54 +0800 Subject: [PATCH 16/70] [SPARK-27064][SS] create StreamingWrite at the beginning of streaming execution ## What changes were proposed in this pull request? According to the [design](https://docs.google.com/document/d/1vI26UEuDpVuOjWw4WPoH2T6y8WAekwtI7qoowhOFnI4/edit?usp=sharing), the life cycle of `StreamingWrite` should be the same as the read side `MicroBatch/ContinuousStream`, i.e. each run of the stream query, instead of each epoch. This PR fixes it. ## How was this patch tested? existing tests Closes #23981 from cloud-fan/dsv2. 
Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../kafka010/KafkaContinuousSinkSuite.scala | 101 +++++------------- .../streaming/MicroBatchExecution.scala | 18 ++-- .../execution/streaming/StreamExecution.scala | 2 +- .../continuous/ContinuousExecution.scala | 20 ++-- .../sources/WriteToMicroBatchDataSource.scala | 39 +++++++ .../sources/StreamingDataSourceV2Suite.scala | 18 +++- 6 files changed, 104 insertions(+), 94 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala index b21037b1340ce..3c3aeebc48b7f 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousSinkSuite.scala @@ -22,9 +22,8 @@ import java.util.Locale import org.apache.kafka.clients.producer.ProducerConfig import org.apache.kafka.common.serialization.ByteArraySerializer import org.scalatest.time.SpanSugar._ -import scala.collection.JavaConverters._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, Row} import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} import org.apache.spark.sql.streaming._ import org.apache.spark.sql.types.{BinaryType, DataType} @@ -227,39 +226,23 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val topic = newTopic() testUtils.createTopic(topic) - /* No topic field or topic option */ - var writer: StreamingQuery = null - var ex: Exception = null - try { - writer = createKafkaWriter(input.toDF())( + val ex = intercept[AnalysisException] { + /* No topic field or topic option */ + createKafkaWriter(input.toDF())( withSelectExpr = "value as key", "value" ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - ex = writer.exception.get - } - } finally { - writer.stop() } assert(ex.getMessage .toLowerCase(Locale.ROOT) .contains("topic option required when no 'topic' attribute is present")) - try { + val ex2 = intercept[AnalysisException] { /* No value field */ - writer = createKafkaWriter(input.toDF())( + createKafkaWriter(input.toDF())( withSelectExpr = s"'$topic' as topic", "value as key" ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - ex = writer.exception.get - } - } finally { - writer.stop() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + assert(ex2.getMessage.toLowerCase(Locale.ROOT).contains( "required attribute 'value' not found")) } @@ -278,53 +261,30 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { val topic = newTopic() testUtils.createTopic(topic) - var writer: StreamingQuery = null - var ex: Exception = null - try { + val ex = intercept[AnalysisException] { /* topic field wrong type */ - writer = createKafkaWriter(input.toDF())( + createKafkaWriter(input.toDF())( withSelectExpr = s"CAST('1' as INT) as topic", "value" ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - ex = writer.exception.get - } - 
} finally { - writer.stop() } assert(ex.getMessage.toLowerCase(Locale.ROOT).contains("topic type must be a string")) - try { + val ex2 = intercept[AnalysisException] { /* value field wrong type */ - writer = createKafkaWriter(input.toDF())( + createKafkaWriter(input.toDF())( withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - ex = writer.exception.get - } - } finally { - writer.stop() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + assert(ex2.getMessage.toLowerCase(Locale.ROOT).contains( "value attribute type must be a string or binary")) - try { + val ex3 = intercept[AnalysisException] { /* key field wrong type */ - writer = createKafkaWriter(input.toDF())( + createKafkaWriter(input.toDF())( withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" ) - testUtils.sendMessages(inputTopic, Array("1", "2", "3", "4", "5")) - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - ex = writer.exception.get - } - } finally { - writer.stop() } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + assert(ex3.getMessage.toLowerCase(Locale.ROOT).contains( "key attribute type must be a string or binary")) } @@ -369,35 +329,22 @@ class KafkaContinuousSinkSuite extends KafkaContinuousTest { .option("kafka.bootstrap.servers", testUtils.brokerAddress) .option("subscribe", inputTopic) .load() - var writer: StreamingQuery = null - var ex: Exception = null - try { - writer = createKafkaWriter( + + val ex = intercept[IllegalArgumentException] { + createKafkaWriter( input.toDF(), withOptions = Map("kafka.key.serializer" -> "foo"))() - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - ex = writer.exception.get - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "kafka option 'key.serializer' is not supported")) - } finally { - writer.stop() } + assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'key.serializer' is not supported")) - try { - writer = createKafkaWriter( + val ex2 = intercept[IllegalArgumentException] { + createKafkaWriter( input.toDF(), withOptions = Map("kafka.value.serializer" -> "foo"))() - eventually(timeout(streamingTimeout)) { - assert(writer.exception.isDefined) - ex = writer.exception.get - } - assert(ex.getMessage.toLowerCase(Locale.ROOT).contains( - "kafka option 'value.serializer' is not supported")) - } finally { - writer.stop() } + assert(ex2.getMessage.toLowerCase(Locale.ROOT).contains( + "kafka option 'value.serializer' is not supported")) } test("generic - write big data with small producer buffer") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index de7cbe25ceb3b..bedcb9f8d4e12 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -26,8 +26,8 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, CurrentBatch import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LocalRelation, LogicalPlan, Project} import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.execution.SQLExecution -import 
org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, StreamWriterCommitProgress, WriteToDataSourceV2, WriteToDataSourceV2Exec} -import org.apache.spark.sql.execution.streaming.sources.{MicroBatchWrite, RateControlMicroBatchStream} +import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relation, StreamWriterCommitProgress, WriteToDataSourceV2Exec} +import org.apache.spark.sql.execution.streaming.sources.{RateControlMicroBatchStream, WriteToMicroBatchDataSource} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset => OffsetV2} @@ -122,7 +122,14 @@ class MicroBatchExecution( case r: StreamingDataSourceV2Relation => r.stream } uniqueSources = sources.distinct - _logicalPlan + + sink match { + case s: SupportsStreamingWrite => + val streamingWrite = createStreamingWrite(s, extraOptions, _logicalPlan) + WriteToMicroBatchDataSource(streamingWrite, _logicalPlan) + + case _ => _logicalPlan + } } /** @@ -513,9 +520,8 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan - case s: SupportsStreamingWrite => - val streamingWrite = createStreamingWrite(s, extraOptions, newAttributePlan) - WriteToDataSourceV2(new MicroBatchWrite(currentBatchId, streamingWrite), newAttributePlan) + case _: SupportsStreamingWrite => + newAttributePlan.asInstanceOf[WriteToMicroBatchDataSource].createPlan(currentBatchId) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index bba640eea7e5b..180a23c765dd3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -585,7 +585,7 @@ abstract class StreamExecution( options: Map[String, String], inputPlan: LogicalPlan): StreamingWrite = { val writeBuilder = table.newWriteBuilder(new DataSourceOptions(options.asJava)) - .withQueryId(runId.toString) + .withQueryId(id.toString) .withInputDataSchema(inputPlan.schema) outputMode match { case Append => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index aef556d92cc83..f55a45d2cee73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -61,7 +61,7 @@ class ContinuousExecution( // Throwable that caused the execution to fail private val failure: AtomicReference[Throwable] = new AtomicReference[Throwable](null) - override val logicalPlan: LogicalPlan = { + override val logicalPlan: WriteToContinuousDataSource = { val v2ToRelationMap = MutableMap[StreamingRelationV2, StreamingDataSourceV2Relation]() var nextSourceId = 0 val _logicalPlan = analyzedPlan.transform { @@ -88,7 +88,8 @@ class ContinuousExecution( } uniqueSources = sources.distinct - _logicalPlan + WriteToContinuousDataSource( + createStreamingWrite(sink, extraOptions, _logicalPlan), _logicalPlan) } private val triggerExecutor = trigger match { @@ -178,13 +179,10 @@ class 
ContinuousExecution( "CurrentTimestamp and CurrentDate not yet supported for continuous processing") } - val streamingWrite = createStreamingWrite(sink, extraOptions, withNewSources) - val planWithSink = WriteToContinuousDataSource(streamingWrite, withNewSources) - reportTimeTaken("queryPlanning") { lastExecution = new IncrementalExecution( sparkSessionForQuery, - planWithSink, + withNewSources, outputMode, checkpointFile("state"), id, @@ -194,7 +192,7 @@ class ContinuousExecution( lastExecution.executedPlan // Force the lazy generation of execution plan } - val stream = planWithSink.collect { + val stream = withNewSources.collect { case relation: StreamingDataSourceV2Relation => relation.stream.asInstanceOf[ContinuousStream] }.head @@ -215,7 +213,13 @@ class ContinuousExecution( // Use the parent Spark session for the endpoint since it's where this query ID is registered. val epochEndpoint = EpochCoordinatorRef.create( - streamingWrite, stream, this, epochCoordinatorId, currentBatchId, sparkSession, SparkEnv.get) + logicalPlan.write, + stream, + this, + epochCoordinatorId, + currentBatchId, + sparkSession, + SparkEnv.get) val epochUpdateThread = new Thread(new Runnable { override def run: Unit = { try { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala new file mode 100644 index 0000000000000..a3f58fa966fe8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/WriteToMicroBatchDataSource.scala @@ -0,0 +1,39 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.streaming.sources + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.datasources.v2.WriteToDataSourceV2 +import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite + +/** + * The logical plan for writing data to a micro-batch stream. + * + * Note that this logical plan does not have a corresponding physical plan, as it will be converted + * to [[WriteToDataSourceV2]] with [[MicroBatchWrite]] before execution. 
+ */ +case class WriteToMicroBatchDataSource(write: StreamingWrite, query: LogicalPlan) + extends LogicalPlan { + override def children: Seq[LogicalPlan] = Seq(query) + override def output: Seq[Attribute] = Nil + + def createPlan(batchId: Long): WriteToDataSourceV2 = { + WriteToDataSourceV2(new MicroBatchWrite(batchId, write), query) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index c841793fdd4a7..553b48398c9ba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -26,7 +26,8 @@ import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming._ -import org.apache.spark.sql.sources.v2.writer.WriteBuilder +import org.apache.spark.sql.sources.v2.writer.{WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamTest, Trigger} import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils @@ -59,6 +60,19 @@ class FakeScanBuilder extends ScanBuilder with Scan { override def toContinuousStream(checkpointLocation: String): ContinuousStream = new FakeDataStream } +class FakeWriteBuilder extends WriteBuilder with StreamingWrite { + override def buildForStreaming(): StreamingWrite = this + override def createStreamingWriterFactory(): StreamingDataWriterFactory = { + throw new IllegalStateException("fake sink - cannot actually write") + } + override def commit(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + throw new IllegalStateException("fake sink - cannot actually write") + } + override def abort(epochId: Long, messages: Array[WriterCommitMessage]): Unit = { + throw new IllegalStateException("fake sink - cannot actually write") + } +} + trait FakeMicroBatchReadTable extends Table with SupportsMicroBatchRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) @@ -75,7 +89,7 @@ trait FakeStreamingWriteTable extends Table with SupportsStreamingWrite { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { - throw new IllegalStateException("fake sink - cannot actually write") + new FakeWriteBuilder } } From 1609b3fd6cb9339e9cb026bf98e341b74b7b20c3 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 14 Mar 2019 01:23:27 +0800 Subject: [PATCH 17/70] [SPARK-27106][SQL] merge CaseInsensitiveStringMap and DataSourceOptions It's a little awkward to have 2 different classes(`CaseInsensitiveStringMap` and `DataSourceOptions`) to present the options in data source and catalog API. This PR merges these 2 classes, while keeping the name `CaseInsensitiveStringMap`, which is more precise. existing tests Closes #24025 from cloud-fan/option. 
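As a rough usage sketch (the option keys below are just examples), the merged class keeps the case-insensitive lookup plus the typed getters that `DataSourceOptions` provided, with one behavior change: `getBoolean` now rejects values that are neither "true" nor "false" instead of silently parsing them as false:

```scala
import scala.collection.JavaConverters._
import org.apache.spark.sql.util.CaseInsensitiveStringMap

val options = new CaseInsensitiveStringMap(
  Map("minPartitions" -> "4", "failOnDataLoss" -> "true").asJava)

options.get("MINPARTITIONS")                // "4": keys are matched case-insensitively
options.getInt("minpartitions", 1)          // 4
options.getBoolean("failOnDataLoss", false) // true
options.getBoolean("minPartitions", false)  // throws IllegalArgumentException("4 is not a boolean string.")
```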
Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../sql/kafka010/KafkaContinuousStream.scala | 3 +- .../sql/kafka010/KafkaMicroBatchStream.scala | 7 +- .../kafka010/KafkaOffsetRangeCalculator.scala | 6 +- .../sql/kafka010/KafkaSourceProvider.scala | 19 +- .../kafka010/KafkaMicroBatchSourceSuite.scala | 4 +- .../KafkaOffsetRangeCalculatorSuite.scala | 8 +- .../sql/util/CaseInsensitiveStringMap.java | 66 +++++- .../util/CaseInsensitiveStringMapSuite.java | 48 ---- .../util/CaseInsensitiveStringMapSuite.scala} | 62 ++---- .../sql/sources/v2/DataSourceOptions.java | 210 ------------------ .../sql/sources/v2/SupportsBatchRead.java | 5 +- .../sql/sources/v2/SupportsBatchWrite.java | 5 +- .../sources/v2/SupportsContinuousRead.java | 5 +- .../sources/v2/SupportsMicroBatchRead.java | 5 +- .../spark/sql/sources/v2/SupportsRead.java | 7 +- .../sources/v2/SupportsStreamingWrite.java | 5 +- .../spark/sql/sources/v2/SupportsWrite.java | 5 +- .../spark/sql/sources/v2/TableProvider.java | 5 +- .../apache/spark/sql/DataFrameReader.scala | 18 +- .../apache/spark/sql/DataFrameWriter.scala | 8 +- .../datasources/FallbackOrcDataSourceV2.scala | 13 +- .../datasources/noop/NoopDataSource.scala | 5 +- .../v2/DataSourceV2Implicits.scala | 8 +- .../datasources/v2/DataSourceV2Relation.scala | 7 +- .../datasources/v2/DataSourceV2Strategy.scala | 8 +- .../datasources/v2/FileDataSourceV2.scala | 12 + .../execution/datasources/v2/FileTable.scala | 22 +- .../datasources/v2/FileWriteBuilder.scala | 15 +- .../v2/WriteToDataSourceV2Exec.scala | 18 +- .../datasources/v2/orc/OrcDataSourceV2.scala | 21 +- .../datasources/v2/orc/OrcScanBuilder.scala | 7 +- .../datasources/v2/orc/OrcTable.scala | 14 +- .../datasources/v2/orc/OrcWriteBuilder.scala | 6 +- .../streaming/MicroBatchExecution.scala | 3 +- .../execution/streaming/StreamExecution.scala | 5 +- .../streaming/StreamingRelation.scala | 3 +- .../sql/execution/streaming/console.scala | 5 +- .../continuous/ContinuousExecution.scala | 6 +- .../ContinuousRateStreamSource.scala | 6 +- .../ContinuousTextSocketSource.scala | 4 +- .../sql/execution/streaming/memory.scala | 7 +- .../streaming/sources/ConsoleWrite.scala | 4 +- .../sources/ForeachWriterTable.scala | 5 +- .../sources/RateStreamMicroBatchStream.scala | 6 +- .../sources/RateStreamProvider.scala | 9 +- .../sources/TextSocketMicroBatchStream.scala | 4 +- .../sources/TextSocketSourceProvider.scala | 17 +- .../streaming/sources/memoryV2.scala | 5 +- .../sql/streaming/DataStreamReader.scala | 5 +- .../sql/streaming/DataStreamWriter.scala | 5 +- .../sources/v2/JavaAdvancedDataSourceV2.java | 6 +- .../sources/v2/JavaColumnarDataSourceV2.java | 6 +- .../v2/JavaPartitionAwareDataSource.java | 6 +- .../v2/JavaReportStatisticsDataSource.java | 6 +- .../v2/JavaSchemaRequiredDataSource.java | 8 +- .../sources/v2/JavaSimpleDataSourceV2.java | 6 +- .../datasources/orc/OrcFilterSuite.scala | 5 +- .../sources/RateStreamProviderSuite.scala | 11 +- .../sources/TextSocketStreamSuite.scala | 18 +- .../sql/sources/v2/DataSourceV2Suite.scala | 35 +-- .../v2/FileDataSourceV2FallBackSuite.scala | 11 +- .../sources/v2/SimpleWritableDataSource.scala | 14 +- .../sources/StreamingDataSourceV2Suite.scala | 37 +-- 63 files changed, 365 insertions(+), 560 deletions(-) delete mode 100644 sql/catalyst/src/test/java/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.java rename sql/{core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala => 
catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala} (52%) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala index 0e6171724402e..d60ee1cadd195 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala @@ -37,8 +37,7 @@ import org.apache.spark.sql.sources.v2.reader.streaming._ * @param offsetReader a reader used to get kafka offsets. Note that the actual data will be * read by per-task consumers generated later. * @param kafkaParams String params for per-task Kafka consumers. - * @param sourceOptions The [[org.apache.spark.sql.sources.v2.DataSourceOptions]] params which - * are not Kafka consumer params. + * @param sourceOptions Params which are not Kafka consumer params. * @param metadataPath Path to a directory this reader can use for writing metadata. * @param initialOffsets The Kafka offsets to start reading data at. * @param failOnDataLoss Flag indicating whether reading should fail in data loss diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala index 337a51ef7fd80..ae866b48ef74b 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala @@ -33,9 +33,9 @@ import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.streaming.{HDFSMetadataLog, SerializedOffset} import org.apache.spark.sql.execution.streaming.sources.RateControlMicroBatchStream import org.apache.spark.sql.kafka010.KafkaSourceProvider.{INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_FALSE, INSTRUCTION_FOR_FAIL_ON_DATA_LOSS_TRUE} -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset} +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.UninterruptibleThread /** @@ -57,7 +57,7 @@ import org.apache.spark.util.UninterruptibleThread private[kafka010] class KafkaMicroBatchStream( kafkaOffsetReader: KafkaOffsetReader, executorKafkaParams: ju.Map[String, Object], - options: DataSourceOptions, + options: CaseInsensitiveStringMap, metadataPath: String, startingOffsets: KafkaOffsetRangeLimit, failOnDataLoss: Boolean) extends RateControlMicroBatchStream with Logging { @@ -66,8 +66,7 @@ private[kafka010] class KafkaMicroBatchStream( "kafkaConsumer.pollTimeoutMs", SparkEnv.get.conf.getTimeAsSeconds("spark.network.timeout", "120s") * 1000L) - private val maxOffsetsPerTrigger = - Option(options.get("maxOffsetsPerTrigger").orElse(null)).map(_.toLong) + private val maxOffsetsPerTrigger = Option(options.get("maxOffsetsPerTrigger")).map(_.toLong) private val rangeCalculator = KafkaOffsetRangeCalculator(options) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala 
b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala index 6008794924052..1af8404b89c68 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculator.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition -import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** @@ -91,8 +91,8 @@ private[kafka010] class KafkaOffsetRangeCalculator(val minPartitions: Option[Int private[kafka010] object KafkaOffsetRangeCalculator { - def apply(options: DataSourceOptions): KafkaOffsetRangeCalculator = { - val optionalValue = Option(options.get("minPartitions").orElse(null)).map(_.toInt) + def apply(options: CaseInsensitiveStringMap): KafkaOffsetRangeCalculator = { + val optionalValue = Option(options.get("minPartitions")).map(_.toInt) new KafkaOffsetRangeCalculator(optionalValue) } } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 01bb1536aa6c5..12f09afdb238d 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.sources.v2.writer.WriteBuilder import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * The provider class for all Kafka readers and writers. 
It is designed such that it throws @@ -102,8 +103,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister failOnDataLoss(caseInsensitiveParams)) } - override def getTable(options: DataSourceOptions): KafkaTable = { - new KafkaTable(strategy(options.asMap().asScala.toMap)) + override def getTable(options: CaseInsensitiveStringMap): KafkaTable = { + new KafkaTable(strategy(options.asScala.toMap)) } /** @@ -357,11 +358,11 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister override def schema(): StructType = KafkaOffsetReader.kafkaSchema - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new ScanBuilder { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder { override def build(): Scan = new KafkaScan(options) } - override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new WriteBuilder { private var inputSchema: StructType = _ @@ -374,20 +375,20 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister import scala.collection.JavaConverters._ assert(inputSchema != null) - val topic = Option(options.get(TOPIC_OPTION_KEY).orElse(null)).map(_.trim) - val producerParams = kafkaParamsForProducer(options.asMap.asScala.toMap) + val topic = Option(options.get(TOPIC_OPTION_KEY)).map(_.trim) + val producerParams = kafkaParamsForProducer(options.asScala.toMap) new KafkaStreamingWrite(topic, producerParams, inputSchema) } } } } - class KafkaScan(options: DataSourceOptions) extends Scan { + class KafkaScan(options: CaseInsensitiveStringMap) extends Scan { override def readSchema(): StructType = KafkaOffsetReader.kafkaSchema override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = { - val parameters = options.asMap().asScala.toMap + val parameters = options.asScala.toMap validateStreamOptions(parameters) // Each running query should use its own group id. Otherwise, the query may be only assigned // partial data since Kafka will assign partitions to multiple consumers having the same group @@ -416,7 +417,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } override def toContinuousStream(checkpointLocation: String): ContinuousStream = { - val parameters = options.asMap().asScala.toMap + val parameters = options.asScala.toMap validateStreamOptions(parameters) // Each running query should use its own group id. 
Otherwise, the query may be only assigned // partial data since Kafka will assign partitions to multiple consumers having the same group diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 8fd5790d753af..21634ae2abfa1 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -41,10 +41,10 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.kafka010.KafkaSourceProvider._ -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.util.CaseInsensitiveStringMap abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with KafkaTest { @@ -1118,7 +1118,7 @@ class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBase { "kafka.bootstrap.servers" -> testUtils.brokerAddress, "subscribe" -> topic ) ++ Option(minPartitions).map { p => "minPartitions" -> p} - val dsOptions = new DataSourceOptions(options.asJava) + val dsOptions = new CaseInsensitiveStringMap(options.asJava) val table = provider.getTable(dsOptions) val stream = table.newScanBuilder(dsOptions).build().toMicroBatchStream(dir.getAbsolutePath) val inputPartitions = stream.planInputPartitions( diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala index 2ccf3e291bea7..7ffdaab3e74fb 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaOffsetRangeCalculatorSuite.scala @@ -22,13 +22,13 @@ import scala.collection.JavaConverters._ import org.apache.kafka.common.TopicPartition import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.sources.v2.DataSourceOptions +import org.apache.spark.sql.util.CaseInsensitiveStringMap class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite { def testWithMinPartitions(name: String, minPartition: Int) (f: KafkaOffsetRangeCalculator => Unit): Unit = { - val options = new DataSourceOptions(Map("minPartitions" -> minPartition.toString).asJava) + val options = new CaseInsensitiveStringMap(Map("minPartitions" -> minPartition.toString).asJava) test(s"with minPartition = $minPartition: $name") { f(KafkaOffsetRangeCalculator(options)) } @@ -36,7 +36,7 @@ class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite { test("with no minPartition: N TopicPartitions to N offset ranges") { - val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty()) + val calc = KafkaOffsetRangeCalculator(CaseInsensitiveStringMap.empty()) assert( calc.getRanges( fromOffsets = Map(tp1 -> 1), @@ -64,7 +64,7 @@ class KafkaOffsetRangeCalculatorSuite extends SparkFunSuite { } test("with no minPartition: empty ranges ignored") { - val calc = KafkaOffsetRangeCalculator(DataSourceOptions.empty()) + val calc 
= KafkaOffsetRangeCalculator(CaseInsensitiveStringMap.empty()) assert( calc.getRanges( fromOffsets = Map(tp1 -> 1, tp2 -> 1), diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java index 8c5a6c61d8658..704d90ed60adc 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java @@ -31,19 +31,20 @@ * This is used to pass options to v2 implementations to ensure consistent case insensitivity. *

* Methods that return keys in this map, like {@link #entrySet()} and {@link #keySet()}, return - * keys converted to lower case. + * keys converted to lower case. This map doesn't allow null key. */ @Experimental public class CaseInsensitiveStringMap implements Map { public static CaseInsensitiveStringMap empty() { - return new CaseInsensitiveStringMap(); + return new CaseInsensitiveStringMap(new HashMap<>(0)); } private final Map delegate; - private CaseInsensitiveStringMap() { - this.delegate = new HashMap<>(); + public CaseInsensitiveStringMap(Map originalMap) { + this.delegate = new HashMap<>(originalMap.size()); + putAll(originalMap); } @Override @@ -56,9 +57,13 @@ public boolean isEmpty() { return delegate.isEmpty(); } + private String toLowerCase(Object key) { + return key.toString().toLowerCase(Locale.ROOT); + } + @Override public boolean containsKey(Object key) { - return delegate.containsKey(key.toString().toLowerCase(Locale.ROOT)); + return delegate.containsKey(toLowerCase(key)); } @Override @@ -68,17 +73,17 @@ public boolean containsValue(Object value) { @Override public String get(Object key) { - return delegate.get(key.toString().toLowerCase(Locale.ROOT)); + return delegate.get(toLowerCase(key)); } @Override public String put(String key, String value) { - return delegate.put(key.toLowerCase(Locale.ROOT), value); + return delegate.put(toLowerCase(key), value); } @Override public String remove(Object key) { - return delegate.remove(key.toString().toLowerCase(Locale.ROOT)); + return delegate.remove(toLowerCase(key)); } @Override @@ -107,4 +112,49 @@ public Collection values() { public Set> entrySet() { return delegate.entrySet(); } + + /** + * Returns the boolean value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive. + */ + public boolean getBoolean(String key, boolean defaultValue) { + String value = get(key); + // We can't use `Boolean.parseBoolean` here, as it returns false for invalid strings. + if (value == null) { + return defaultValue; + } else if (value.equalsIgnoreCase("true")) { + return true; + } else if (value.equalsIgnoreCase("false")) { + return false; + } else { + throw new IllegalArgumentException(value + " is not a boolean string."); + } + } + + /** + * Returns the integer value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive. + */ + public int getInt(String key, int defaultValue) { + String value = get(key); + return value == null ? defaultValue : Integer.parseInt(value); + } + + /** + * Returns the long value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive. + */ + public long getLong(String key, long defaultValue) { + String value = get(key); + return value == null ? defaultValue : Long.parseLong(value); + } + + /** + * Returns the double value to which the specified key is mapped, + * or defaultValue if there is no mapping for the key. The key match is case-insensitive. + */ + public double getDouble(String key, double defaultValue) { + String value = get(key); + return value == null ? 
defaultValue : Double.parseDouble(value); + } } diff --git a/sql/catalyst/src/test/java/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.java b/sql/catalyst/src/test/java/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.java deleted file mode 100644 index 76392777d42a4..0000000000000 --- a/sql/catalyst/src/test/java/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.util; - -import org.junit.Assert; -import org.junit.Test; - -import java.util.HashSet; -import java.util.Set; - -public class CaseInsensitiveStringMapSuite { - @Test - public void testPutAndGet() { - CaseInsensitiveStringMap options = CaseInsensitiveStringMap.empty(); - options.put("kEy", "valUE"); - - Assert.assertEquals("Should return correct value for lower-case key", - "valUE", options.get("key")); - Assert.assertEquals("Should return correct value for upper-case key", - "valUE", options.get("KEY")); - } - - @Test - public void testKeySet() { - CaseInsensitiveStringMap options = CaseInsensitiveStringMap.empty(); - options.put("kEy", "valUE"); - - Set expectedKeySet = new HashSet<>(); - expectedKeySet.add("key"); - - Assert.assertEquals("Should return lower-case key set", expectedKeySet, options.keySet()); - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala similarity index 52% rename from sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala index cfa69a86de1a7..623ddeb140254 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceOptionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala @@ -15,31 +15,29 @@ * limitations under the License. */ -package org.apache.spark.sql.sources.v2 +package org.apache.spark.sql.util import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite -/** - * A simple test suite to verify `DataSourceOptions`. 
- */ -class DataSourceOptionsSuite extends SparkFunSuite { +class CaseInsensitiveStringMapSuite extends SparkFunSuite { - test("key is case-insensitive") { - val options = new DataSourceOptions(Map("foo" -> "bar").asJava) - assert(options.get("foo").get() == "bar") - assert(options.get("FoO").get() == "bar") - assert(!options.get("abc").isPresent) + test("put and get") { + val options = CaseInsensitiveStringMap.empty() + options.put("kEy", "valUE") + assert(options.get("key") == "valUE") + assert(options.get("KEY") == "valUE") } - test("value is case-sensitive") { - val options = new DataSourceOptions(Map("foo" -> "bAr").asJava) - assert(options.get("foo").get == "bAr") + test("key and value set") { + val options = new CaseInsensitiveStringMap(Map("kEy" -> "valUE").asJava) + assert(options.keySet().asScala == Set("key")) + assert(options.values().asScala.toSeq == Seq("valUE")) } test("getInt") { - val options = new DataSourceOptions(Map("numFOo" -> "1", "foo" -> "bar").asJava) + val options = new CaseInsensitiveStringMap(Map("numFOo" -> "1", "foo" -> "bar").asJava) assert(options.getInt("numFOO", 10) == 1) assert(options.getInt("numFOO2", 10) == 10) @@ -49,17 +47,20 @@ class DataSourceOptionsSuite extends SparkFunSuite { } test("getBoolean") { - val options = new DataSourceOptions( + val options = new CaseInsensitiveStringMap( Map("isFoo" -> "true", "isFOO2" -> "false", "foo" -> "bar").asJava) assert(options.getBoolean("isFoo", false)) assert(!options.getBoolean("isFoo2", true)) assert(options.getBoolean("isBar", true)) assert(!options.getBoolean("isBar", false)) - assert(!options.getBoolean("FOO", true)) + + intercept[IllegalArgumentException] { + options.getBoolean("FOO", true) + } } test("getLong") { - val options = new DataSourceOptions(Map("numFoo" -> "9223372036854775807", + val options = new CaseInsensitiveStringMap(Map("numFoo" -> "9223372036854775807", "foo" -> "bar").asJava) assert(options.getLong("numFOO", 0L) == 9223372036854775807L) assert(options.getLong("numFoo2", -1L) == -1L) @@ -70,7 +71,7 @@ class DataSourceOptionsSuite extends SparkFunSuite { } test("getDouble") { - val options = new DataSourceOptions(Map("numFoo" -> "922337.1", + val options = new CaseInsensitiveStringMap(Map("numFoo" -> "922337.1", "foo" -> "bar").asJava) assert(options.getDouble("numFOO", 0d) == 922337.1d) assert(options.getDouble("numFoo2", -1.02d) == -1.02d) @@ -79,29 +80,4 @@ class DataSourceOptionsSuite extends SparkFunSuite { options.getDouble("foo", 0.1d) } } - - test("standard options") { - val options = new DataSourceOptions(Map( - DataSourceOptions.PATH_KEY -> "abc", - DataSourceOptions.TABLE_KEY -> "tbl").asJava) - - assert(options.paths().toSeq == Seq("abc")) - assert(options.tableName().get() == "tbl") - assert(!options.databaseName().isPresent) - } - - test("standard options with both singular path and multi-paths") { - val options = new DataSourceOptions(Map( - DataSourceOptions.PATH_KEY -> "abc", - DataSourceOptions.PATHS_KEY -> """["c", "d"]""").asJava) - - assert(options.paths().toSeq == Seq("abc", "c", "d")) - } - - test("standard options with only multi-paths") { - val options = new DataSourceOptions(Map( - DataSourceOptions.PATHS_KEY -> """["c", "d\"e"]""").asJava) - - assert(options.paths().toSeq == Seq("c", "d\"e")) - } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java deleted file mode 100644 index 00af0bf1b172c..0000000000000 --- 
a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/DataSourceOptions.java +++ /dev/null @@ -1,210 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Locale; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Stream; - -import com.fasterxml.jackson.databind.ObjectMapper; - -import org.apache.spark.annotation.Evolving; - -/** - * An immutable string-to-string map in which keys are case-insensitive. This is used to represent - * data source options. - * - * Each data source implementation can define its own options and teach its users how to set them. - * Spark doesn't have any restrictions about what options a data source should or should not have. - * Instead Spark defines some standard options that data sources can optionally adopt. It's possible - * that some options are very common and many data sources use them. However different data - * sources may define the common options(key and meaning) differently, which is quite confusing to - * end users. - * - * The standard options defined by Spark: - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - * - *
Option key | Option value
path | A path string of the data files/directories, like - * path1, /absolute/file2, path3/*. The path can - * either be relative or absolute, points to either file or directory, and can contain - * wildcards. This option is commonly used by file-based data sources.
paths | A JSON array style paths string of the data files/directories, like - * ["path1", "/absolute/file2"]. The format of each path is same as the - * path option, plus it should follow JSON string literal format, e.g. quotes - * should be escaped, pa\"th means pa"th. - *
table | A table name string representing the table name directly without any interpretation. - * For example, db.tbl means a table called db.tbl, not a table called tbl - * inside database db. `t*b.l` means a table called `t*b.l`, not t*b.l.
database | A database name string representing the database name directly without any - * interpretation, which is very similar to the table name option.
- */ -@Evolving -public class DataSourceOptions { - private final Map keyLowerCasedMap; - - private String toLowerCase(String key) { - return key.toLowerCase(Locale.ROOT); - } - - public static DataSourceOptions empty() { - return new DataSourceOptions(new HashMap<>()); - } - - public DataSourceOptions(Map originalMap) { - keyLowerCasedMap = new HashMap<>(originalMap.size()); - for (Map.Entry entry : originalMap.entrySet()) { - keyLowerCasedMap.put(toLowerCase(entry.getKey()), entry.getValue()); - } - } - - public Map asMap() { - return new HashMap<>(keyLowerCasedMap); - } - - /** - * Returns the option value to which the specified key is mapped, case-insensitively. - */ - public Optional get(String key) { - return Optional.ofNullable(keyLowerCasedMap.get(toLowerCase(key))); - } - - /** - * Returns the boolean value to which the specified key is mapped, - * or defaultValue if there is no mapping for the key. The key match is case-insensitive - */ - public boolean getBoolean(String key, boolean defaultValue) { - String lcaseKey = toLowerCase(key); - return keyLowerCasedMap.containsKey(lcaseKey) ? - Boolean.parseBoolean(keyLowerCasedMap.get(lcaseKey)) : defaultValue; - } - - /** - * Returns the integer value to which the specified key is mapped, - * or defaultValue if there is no mapping for the key. The key match is case-insensitive - */ - public int getInt(String key, int defaultValue) { - String lcaseKey = toLowerCase(key); - return keyLowerCasedMap.containsKey(lcaseKey) ? - Integer.parseInt(keyLowerCasedMap.get(lcaseKey)) : defaultValue; - } - - /** - * Returns the long value to which the specified key is mapped, - * or defaultValue if there is no mapping for the key. The key match is case-insensitive - */ - public long getLong(String key, long defaultValue) { - String lcaseKey = toLowerCase(key); - return keyLowerCasedMap.containsKey(lcaseKey) ? - Long.parseLong(keyLowerCasedMap.get(lcaseKey)) : defaultValue; - } - - /** - * Returns the double value to which the specified key is mapped, - * or defaultValue if there is no mapping for the key. The key match is case-insensitive - */ - public double getDouble(String key, double defaultValue) { - String lcaseKey = toLowerCase(key); - return keyLowerCasedMap.containsKey(lcaseKey) ? - Double.parseDouble(keyLowerCasedMap.get(lcaseKey)) : defaultValue; - } - - /** - * The option key for singular path. - */ - public static final String PATH_KEY = "path"; - - /** - * The option key for multiple paths. - */ - public static final String PATHS_KEY = "paths"; - - /** - * The option key for table name. - */ - public static final String TABLE_KEY = "table"; - - /** - * The option key for database name. - */ - public static final String DATABASE_KEY = "database"; - - /** - * The option key for whether to check existence of files for a table. - */ - public static final String CHECK_FILES_EXIST_KEY = "check_files_exist"; - - /** - * Returns all the paths specified by both the singular path option and the multiple - * paths option. 
- */ - public String[] paths() { - String[] singularPath = - get(PATH_KEY).map(s -> new String[]{s}).orElseGet(() -> new String[0]); - Optional pathsStr = get(PATHS_KEY); - if (pathsStr.isPresent()) { - ObjectMapper objectMapper = new ObjectMapper(); - try { - String[] paths = objectMapper.readValue(pathsStr.get(), String[].class); - return Stream.of(singularPath, paths).flatMap(Stream::of).toArray(String[]::new); - } catch (IOException e) { - return singularPath; - } - } else { - return singularPath; - } - } - - /** - * Returns the value of the table name option. - */ - public Optional tableName() { - return get(TABLE_KEY); - } - - /** - * Returns the value of the database name option. - */ - public Optional databaseName() { - return get(DATABASE_KEY); - } - - public Boolean checkFilesExist() { - Optional result = get(CHECK_FILES_EXIST_KEY); - return result.isPresent() && result.get().equals("true"); - } -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java index 6c5a95d2a75b7..ea7c5d2b108f0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java @@ -20,13 +20,14 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.Scan; import org.apache.spark.sql.sources.v2.reader.ScanBuilder; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** * An empty mix-in interface for {@link Table}, to indicate this table supports batch scan. *

* If a {@link Table} implements this interface, the - * {@link SupportsRead#newScanBuilder(DataSourceOptions)} must return a {@link ScanBuilder} that - * builds {@link Scan} with {@link Scan#toBatch()} implemented. + * {@link SupportsRead#newScanBuilder(CaseInsensitiveStringMap)} must return a {@link ScanBuilder} + * that builds {@link Scan} with {@link Scan#toBatch()} implemented. *

*/ @Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java index b2cd97a2f5332..09e23f84fd6bf 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java @@ -19,13 +19,14 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.writer.WriteBuilder; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** * An empty mix-in interface for {@link Table}, to indicate this table supports batch write. *

* If a {@link Table} implements this interface, the - * {@link SupportsWrite#newWriteBuilder(DataSourceOptions)} must return a {@link WriteBuilder} - * with {@link WriteBuilder#buildForBatch()} implemented. + * {@link SupportsWrite#newWriteBuilder(CaseInsensitiveStringMap)} must return a + * {@link WriteBuilder} with {@link WriteBuilder#buildForBatch()} implemented. *

*/ @Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java index b7fa3f24a238c..5cc9848d9da89 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java @@ -20,14 +20,15 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.Scan; import org.apache.spark.sql.sources.v2.reader.ScanBuilder; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** * An empty mix-in interface for {@link Table}, to indicate this table supports streaming scan with * continuous mode. *

* If a {@link Table} implements this interface, the - * {@link SupportsRead#newScanBuilder(DataSourceOptions)} must return a {@link ScanBuilder} that - * builds {@link Scan} with {@link Scan#toContinuousStream(String)} implemented. + * {@link SupportsRead#newScanBuilder(CaseInsensitiveStringMap)} must return a {@link ScanBuilder} + * that builds {@link Scan} with {@link Scan#toContinuousStream(String)} implemented. *

*/ @Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java index 9408e323f9da1..c98f3f1aa5cba 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java @@ -20,14 +20,15 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.v2.reader.Scan; import org.apache.spark.sql.sources.v2.reader.ScanBuilder; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** * An empty mix-in interface for {@link Table}, to indicate this table supports streaming scan with * micro-batch mode. *

* If a {@link Table} implements this interface, the - * {@link SupportsRead#newScanBuilder(DataSourceOptions)} must return a {@link ScanBuilder} that - * builds {@link Scan} with {@link Scan#toMicroBatchStream(String)} implemented. + * {@link SupportsRead#newScanBuilder(CaseInsensitiveStringMap)} must return a {@link ScanBuilder} + * that builds {@link Scan} with {@link Scan#toMicroBatchStream(String)} implemented. *

*/ @Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java index 5031c71c0fd4d..14990effeda37 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java @@ -19,11 +19,12 @@ import org.apache.spark.sql.sources.v2.reader.Scan; import org.apache.spark.sql.sources.v2.reader.ScanBuilder; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** * An internal base interface of mix-in interfaces for readable {@link Table}. This adds - * {@link #newScanBuilder(DataSourceOptions)} that is used to create a scan for batch, micro-batch, - * or continuous processing. + * {@link #newScanBuilder(CaseInsensitiveStringMap)} that is used to create a scan for batch, + * micro-batch, or continuous processing. */ interface SupportsRead extends Table { @@ -34,5 +35,5 @@ interface SupportsRead extends Table { * @param options The options for reading, which is an immutable case-insensitive * string-to-string map. */ - ScanBuilder newScanBuilder(DataSourceOptions options); + ScanBuilder newScanBuilder(CaseInsensitiveStringMap options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java index 1050d35250c1f..ac11e483c18c4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java @@ -20,13 +20,14 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.execution.streaming.BaseStreamingSink; import org.apache.spark.sql.sources.v2.writer.WriteBuilder; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** * An empty mix-in interface for {@link Table}, to indicate this table supports streaming write. *

* If a {@link Table} implements this interface, the - * {@link SupportsWrite#newWriteBuilder(DataSourceOptions)} must return a {@link WriteBuilder} - * with {@link WriteBuilder#buildForStreaming()} implemented. + * {@link SupportsWrite#newWriteBuilder(CaseInsensitiveStringMap)} must return a + * {@link WriteBuilder} with {@link WriteBuilder#buildForStreaming()} implemented. *

*/ @Evolving diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java index ecdfe20730254..f0d8e44f15287 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java @@ -19,10 +19,11 @@ import org.apache.spark.sql.sources.v2.writer.BatchWrite; import org.apache.spark.sql.sources.v2.writer.WriteBuilder; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** * An internal base interface of mix-in interfaces for writable {@link Table}. This adds - * {@link #newWriteBuilder(DataSourceOptions)} that is used to create a write + * {@link #newWriteBuilder(CaseInsensitiveStringMap)} that is used to create a write * for batch or streaming. */ interface SupportsWrite extends Table { @@ -31,5 +32,5 @@ interface SupportsWrite extends Table { * Returns a {@link WriteBuilder} which can be used to create {@link BatchWrite}. Spark will call * this method to configure each data source write. */ - WriteBuilder newWriteBuilder(DataSourceOptions options); + WriteBuilder newWriteBuilder(CaseInsensitiveStringMap options); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java index a9b83b6de9950..04ad8fd90be9f 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java @@ -20,6 +20,7 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.sources.DataSourceRegister; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** * The base interface for v2 data sources which don't have a real catalog. Implementations must @@ -37,7 +38,7 @@ public interface TableProvider { * @param options the user-specified options that can identify a table, e.g. file path, Kafka * topic name, etc. It's an immutable case-insensitive string-to-string map. */ - Table getTable(DataSourceOptions options); + Table getTable(CaseInsensitiveStringMap options); /** * Return a {@link Table} instance to do read/write with user-specified schema and options. @@ -50,7 +51,7 @@ public interface TableProvider { * @param schema the user-specified schema. 
* @throws UnsupportedOperationException */ - default Table getTable(DataSourceOptions options, StructType schema) { + default Table getTable(CaseInsensitiveStringMap options, StructType schema) { String name; if (this instanceof DataSourceRegister) { name = ((DataSourceRegister) this).shortName(); diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index ffa19895ee3c7..2235217b9c1ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -40,6 +40,7 @@ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, FileTable} import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.UTF8String /** @@ -176,7 +177,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { */ def load(path: String): DataFrame = { // force invocation of `load(...varargs...)` - option(DataSourceOptions.PATH_KEY, path).load(Seq.empty: _*) + option("path", path).load(Seq.empty: _*) } /** @@ -206,20 +207,23 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] val sessionOptions = DataSourceV2Utils.extractSessionConfigs( source = provider, conf = sparkSession.sessionState.conf) - val pathsOption = { + val pathsOption = if (paths.isEmpty) { + None + } else { val objectMapper = new ObjectMapper() - DataSourceOptions.PATHS_KEY -> objectMapper.writeValueAsString(paths.toArray) + Some("paths" -> objectMapper.writeValueAsString(paths.toArray)) } - val checkFilesExistsOption = DataSourceOptions.CHECK_FILES_EXIST_KEY -> "true" - val finalOptions = sessionOptions ++ extraOptions.toMap + pathsOption + checkFilesExistsOption - val dsOptions = new DataSourceOptions(finalOptions.asJava) + // TODO SPARK-27113: remove this option. + val checkFilesExistsOpt = "check_files_exist" -> "true" + val finalOptions = sessionOptions ++ extraOptions.toMap ++ pathsOption + checkFilesExistsOpt + val dsOptions = new CaseInsensitiveStringMap(finalOptions.asJava) val table = userSpecifiedSchema match { case Some(schema) => provider.getTable(dsOptions, schema) case _ => provider.getTable(dsOptions) } table match { case _: SupportsBatchRead => - Dataset.ofRows(sparkSession, DataSourceV2Relation.create(table, finalOptions)) + Dataset.ofRows(sparkSession, DataSourceV2Relation.create(table, dsOptions)) case _ => loadV1Source(paths: _*) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index b5cfa85f6fb21..9f766cfccdf93 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -35,6 +35,7 @@ import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * Interface used to write a [[Dataset]] to external storage systems (e.g. 
file systems, @@ -260,12 +261,13 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val provider = cls.getConstructor().newInstance().asInstanceOf[TableProvider] val sessionOptions = DataSourceV2Utils.extractSessionConfigs( provider, session.sessionState.conf) - val checkFilesExistsOption = DataSourceOptions.CHECK_FILES_EXIST_KEY -> "false" + // TODO SPARK-27113: remove this option. + val checkFilesExistsOption = "check_files_exist" -> "false" val options = sessionOptions ++ extraOptions + checkFilesExistsOption - val dsOptions = new DataSourceOptions(options.asJava) + val dsOptions = new CaseInsensitiveStringMap(options.asJava) provider.getTable(dsOptions) match { case table: SupportsBatchWrite => - lazy val relation = DataSourceV2Relation.create(table, options) + lazy val relation = DataSourceV2Relation.create(table, dsOptions) mode match { case SaveMode.Append => runCommand(df.sparkSession, "save") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala index e22d6a6d399a5..7c72495548e3a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FallbackOrcDataSourceV2.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources +import scala.collection.JavaConverters._ + import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.plans.logical.{InsertIntoTable, LogicalPlan} import org.apache.spark.sql.catalyst.rules.Rule @@ -33,10 +35,15 @@ import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable */ class FallbackOrcDataSourceV2(sparkSession: SparkSession) extends Rule[LogicalPlan] { override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { - case i @ InsertIntoTable(d @DataSourceV2Relation(table: OrcTable, _, _), _, _, _, _) => + case i @ InsertIntoTable(d @ DataSourceV2Relation(table: OrcTable, _, _), _, _, _, _) => val v1FileFormat = new OrcFileFormat - val relation = HadoopFsRelation(table.fileIndex, table.fileIndex.partitionSchema, - table.schema(), None, v1FileFormat, d.options)(sparkSession) + val relation = HadoopFsRelation( + table.fileIndex, + table.fileIndex.partitionSchema, + table.schema(), + None, + v1FileFormat, + d.options.asScala.toMap)(sparkSession) i.copy(table = LogicalRelation(relation)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 22a74e3ccaeee..aa2a5e9a06fbd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * This is no-op datasource. It does not do anything besides consuming its input. 
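Editor's note: since the next hunk migrates the no-op sink, here is a small usage sketch (not part of the patch) showing how it is typically exercised from the batch write path; the overwrite mode is just a common choice for benchmarking runs:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("noop-demo").getOrCreate()

// Writes are planned and executed end to end, but every row is discarded by the sink.
spark.range(0, 1000000)
  .write
  .format("noop")
  .mode("overwrite")
  .save()
```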
@@ -31,11 +32,11 @@ import org.apache.spark.sql.types.StructType */ class NoopDataSource extends TableProvider with DataSourceRegister { override def shortName(): String = "noop" - override def getTable(options: DataSourceOptions): Table = NoopTable + override def getTable(options: CaseInsensitiveStringMap): Table = NoopTable } private[noop] object NoopTable extends Table with SupportsBatchWrite with SupportsStreamingWrite { - override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = NoopWriteBuilder + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = NoopWriteBuilder override def name(): String = "noop-table" override def schema(): StructType = new StructType() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala index c8542bfe5e59b..2081af35ce2d1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala @@ -17,10 +17,8 @@ package org.apache.spark.sql.execution.datasources.v2 -import scala.collection.JavaConverters._ - import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsBatchRead, SupportsBatchWrite, Table} +import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table} object DataSourceV2Implicits { implicit class TableHelper(table: Table) { @@ -42,8 +40,4 @@ object DataSourceV2Implicits { } } } - - implicit class OptionsHelper(options: Map[String, String]) { - def toDataSourceOptions: DataSourceOptions = new DataSourceOptions(options.asJava) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 891694be46291..17407827d0564 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.{Statistics => V2Statistics, _} import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream} import org.apache.spark.sql.sources.v2.writer._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * A logical plan representing a data source v2 table. 
@@ -36,7 +37,7 @@ import org.apache.spark.sql.sources.v2.writer._ case class DataSourceV2Relation( table: Table, output: Seq[AttributeReference], - options: Map[String, String]) + options: CaseInsensitiveStringMap) extends LeafNode with MultiInstanceRelation with NamedRelation { import DataSourceV2Implicits._ @@ -48,7 +49,7 @@ case class DataSourceV2Relation( } def newScanBuilder(): ScanBuilder = { - table.asBatchReadable.newScanBuilder(options.toDataSourceOptions) + table.asBatchReadable.newScanBuilder(options) } override def computeStats(): Statistics = { @@ -96,7 +97,7 @@ case class StreamingDataSourceV2Relation( } object DataSourceV2Relation { - def create(table: Table, options: Map[String, String]): DataSourceV2Relation = { + def create(table: Table, options: CaseInsensitiveStringMap): DataSourceV2Relation = { val output = table.schema().toAttributes DataSourceV2Relation(table, output, options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 55d7b0a18cbc8..b3a65eeac4dbc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -147,8 +147,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil case AppendData(r: DataSourceV2Relation, query, _) => - AppendDataExec( - r.table.asBatchWritable, r.options.toDataSourceOptions, planLater(query)) :: Nil + AppendDataExec(r.table.asBatchWritable, r.options, planLater(query)) :: Nil case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, _) => // fail if any filter cannot be converted. correctness depends on removing all matching data. @@ -158,11 +157,10 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { }.toArray OverwriteByExpressionExec( - r.table.asBatchWritable, filters, r.options.toDataSourceOptions, planLater(query)) :: Nil + r.table.asBatchWritable, filters, r.options, planLater(query)) :: Nil case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _) => - OverwritePartitionsDynamicExec(r.table.asBatchWritable, - r.options.toDataSourceOptions, planLater(query)) :: Nil + OverwritePartitionsDynamicExec(r.table.asBatchWritable, r.options, planLater(query)) :: Nil case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala index 06c57066aa240..e9c7a1bb749db 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileDataSourceV2.scala @@ -16,10 +16,13 @@ */ package org.apache.spark.sql.execution.datasources.v2 +import com.fasterxml.jackson.databind.ObjectMapper + import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2.TableProvider +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * A base interface for data source v2 implementations of the built-in file-based data sources. 
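Editor's note: the `getPaths` helper added in the next hunk parses the same JSON-encoded `paths` option that `DataFrameReader` now writes with Jackson. A small round-trip sketch, assuming only the option names used in this patch (`path`, `paths`); the path values are made up:

```scala
import scala.collection.JavaConverters._
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.spark.sql.util.CaseInsensitiveStringMap

val mapper = new ObjectMapper()

// Writer side: multiple paths are serialized into a single JSON array string.
val encoded = mapper.writeValueAsString(Array("/data/a", "/data/b")) // ["/data/a","/data/b"]

val options = new CaseInsensitiveStringMap(
  Map("paths" -> encoded, "path" -> "/data/c").asJava)

// Reader side: mirrors FileDataSourceV2.getPaths: prefer "paths", fall back to "path".
val paths: Seq[String] =
  Option(options.get("paths"))
    .map(s => mapper.readValue(s, classOf[Array[String]]).toSeq)
    .getOrElse(Option(options.get("path")).toSeq)
```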
@@ -35,4 +38,13 @@ trait FileDataSourceV2 extends TableProvider with DataSourceRegister { def fallBackFileFormat: Class[_ <: FileFormat] lazy val sparkSession = SparkSession.active + + protected def getPaths(map: CaseInsensitiveStringMap): Seq[String] = { + val objectMapper = new ObjectMapper() + Option(map.get("paths")).map { pathStr => + objectMapper.readValue(pathStr, classOf[Array[String]]).toSeq + }.getOrElse { + Option(map.get("path")).toSeq + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index 21d3e5e29cfb5..08873a3b5a643 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -22,23 +22,27 @@ import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsBatchRead, SupportsBatchWrite, Table} +import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap abstract class FileTable( sparkSession: SparkSession, - options: DataSourceOptions, + options: CaseInsensitiveStringMap, + paths: Seq[String], userSpecifiedSchema: Option[StructType]) extends Table with SupportsBatchRead with SupportsBatchWrite { + lazy val fileIndex: PartitioningAwareFileIndex = { - val filePaths = options.paths() - val hadoopConf = - sparkSession.sessionState.newHadoopConfWithOptions(options.asMap().asScala.toMap) - val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(filePaths, hadoopConf, - checkEmptyGlobPath = true, checkFilesExist = options.checkFilesExist()) + val scalaMap = options.asScala.toMap + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(scalaMap) + // This is an internal config so must be present. 
+ val checkFilesExist = options.get("check_files_exist").toBoolean + val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(paths, hadoopConf, + checkEmptyGlobPath = true, checkFilesExist = checkFilesExist) val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) - new InMemoryFileIndex(sparkSession, rootPathsSpecified, - options.asMap().asScala.toMap, userSpecifiedSchema, fileStatusCache) + new InMemoryFileIndex( + sparkSession, rootPathsSpecified, scalaMap, userSpecifiedSchema, fileStatusCache) } lazy val dataSchema: StructType = userSpecifiedSchema.orElse { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala index 75c922424e8ef..e16ee4c460f39 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala @@ -33,12 +33,12 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, DataSource, OutputWriterFactory, WriteJobDescription} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder} import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration -abstract class FileWriteBuilder(options: DataSourceOptions) +abstract class FileWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String]) extends WriteBuilder with SupportsSaveMode { private var schema: StructType = _ private var queryId: String = _ @@ -61,18 +61,17 @@ abstract class FileWriteBuilder(options: DataSourceOptions) override def buildForBatch(): BatchWrite = { validateInputs() - val pathName = options.paths().head - val path = new Path(pathName) + val path = new Path(paths.head) val sparkSession = SparkSession.active - val optionsAsScala = options.asMap().asScala.toMap + val optionsAsScala = options.asScala.toMap val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(optionsAsScala) val job = getJobInstance(hadoopConf, path) val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, jobId = java.util.UUID.randomUUID().toString, - outputPath = pathName) + outputPath = paths.head) lazy val description = - createWriteJobDescription(sparkSession, hadoopConf, job, pathName, optionsAsScala) + createWriteJobDescription(sparkSession, hadoopConf, job, paths.head, optionsAsScala) val fs = path.getFileSystem(hadoopConf) mode match { @@ -127,7 +126,7 @@ abstract class FileWriteBuilder(options: DataSourceOptions) assert(schema != null, "Missing input data schema") assert(queryId != null, "Missing query ID") assert(mode != null, "Missing save mode") - assert(options.paths().length == 1) + assert(paths.length == 1) DataSource.validateSchema(schema) schema.foreach { field => if (!supportsDataType(field.dataType)) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index d7cb2457433b0..51606abdb563a 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -31,8 +31,9 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.sources.{AlwaysTrue, Filter} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsBatchWrite} +import org.apache.spark.sql.sources.v2.SupportsBatchWrite import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsSaveMode, SupportsTruncate, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{LongAccumulator, Utils} /** @@ -53,7 +54,7 @@ case class WriteToDataSourceV2(batchWrite: BatchWrite, query: LogicalPlan) */ case class AppendDataExec( table: SupportsBatchWrite, - writeOptions: DataSourceOptions, + writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { override protected def doExecute(): RDD[InternalRow] = { @@ -81,7 +82,7 @@ case class AppendDataExec( case class OverwriteByExpressionExec( table: SupportsBatchWrite, deleteWhere: Array[Filter], - writeOptions: DataSourceOptions, + writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { private def isTruncate(filters: Array[Filter]): Boolean = { @@ -118,7 +119,7 @@ case class OverwriteByExpressionExec( */ case class OverwritePartitionsDynamicExec( table: SupportsBatchWrite, - writeOptions: DataSourceOptions, + writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { override protected def doExecute(): RDD[InternalRow] = { @@ -139,12 +140,9 @@ case class OverwritePartitionsDynamicExec( case class WriteToDataSourceV2Exec( batchWrite: BatchWrite, - query: SparkPlan - ) extends V2TableWriteExec { + query: SparkPlan) extends V2TableWriteExec { - import DataSourceV2Implicits._ - - def writeOptions: DataSourceOptions = Map.empty[String, String].toDataSourceOptions + def writeOptions: CaseInsensitiveStringMap = CaseInsensitiveStringMap.empty() override protected def doExecute(): RDD[InternalRow] = { doWrite(batchWrite) @@ -157,7 +155,7 @@ case class WriteToDataSourceV2Exec( trait BatchWriteHelper { def table: SupportsBatchWrite def query: SparkPlan - def writeOptions: DataSourceOptions + def writeOptions: CaseInsensitiveStringMap def newWriteBuilder(): WriteBuilder = { table.newWriteBuilder(writeOptions) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala index f279af49ba9cf..900c94e937ffc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcDataSourceV2.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.execution.datasources.v2.orc import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.datasources.v2._ -import org.apache.spark.sql.sources.v2.{DataSourceOptions, Table} +import org.apache.spark.sql.sources.v2.Table 
import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap class OrcDataSourceV2 extends FileDataSourceV2 { @@ -28,18 +29,20 @@ class OrcDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "orc" - private def getTableName(options: DataSourceOptions): String = { - shortName() + ":" + options.paths().mkString(";") + private def getTableName(paths: Seq[String]): String = { + shortName() + ":" + paths.mkString(";") } - override def getTable(options: DataSourceOptions): Table = { - val tableName = getTableName(options) - OrcTable(tableName, sparkSession, options, None) + override def getTable(options: CaseInsensitiveStringMap): Table = { + val paths = getPaths(options) + val tableName = getTableName(paths) + OrcTable(tableName, sparkSession, options, paths, None) } - override def getTable(options: DataSourceOptions, schema: StructType): Table = { - val tableName = getTableName(options) - OrcTable(tableName, sparkSession, options, Some(schema)) + override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { + val paths = getPaths(options) + val tableName = getTableName(paths) + OrcTable(tableName, sparkSession, options, paths, Some(schema)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index eb27bbd3abeaa..0b153416b7bb0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -26,18 +26,17 @@ import org.apache.spark.sql.execution.datasources.PartitioningAwareFileIndex import org.apache.spark.sql.execution.datasources.orc.OrcFilters import org.apache.spark.sql.execution.datasources.v2.FileScanBuilder import org.apache.spark.sql.sources.Filter -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.Scan import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap case class OrcScanBuilder( sparkSession: SparkSession, fileIndex: PartitioningAwareFileIndex, schema: StructType, dataSchema: StructType, - options: DataSourceOptions) extends FileScanBuilder(schema) { - lazy val hadoopConf = - sparkSession.sessionState.newHadoopConfWithOptions(options.asMap().asScala.toMap) + options: CaseInsensitiveStringMap) extends FileScanBuilder(schema) { + lazy val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap) override def build(): Scan = { OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, readSchema) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala index 249df8b8622fb..aac38fb3fa1ff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcTable.scala @@ -21,22 +21,24 @@ import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.datasources.orc.OrcUtils import org.apache.spark.sql.execution.datasources.v2.FileTable -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.writer.WriteBuilder import 
org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap case class OrcTable( name: String, sparkSession: SparkSession, - options: DataSourceOptions, + options: CaseInsensitiveStringMap, + paths: Seq[String], userSpecifiedSchema: Option[StructType]) - extends FileTable(sparkSession, options, userSpecifiedSchema) { - override def newScanBuilder(options: DataSourceOptions): OrcScanBuilder = + extends FileTable(sparkSession, options, paths, userSpecifiedSchema) { + + override def newScanBuilder(options: CaseInsensitiveStringMap): OrcScanBuilder = new OrcScanBuilder(sparkSession, fileIndex, schema, dataSchema, options) override def inferSchema(files: Seq[FileStatus]): Option[StructType] = OrcUtils.readSchema(sparkSession, files) - override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = - new OrcWriteBuilder(options) + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = + new OrcWriteBuilder(options, paths) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala index 1aec4d872a64d..829ab5fbe1768 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala @@ -25,10 +25,12 @@ import org.apache.spark.sql.execution.datasources.{OutputWriter, OutputWriterFac import org.apache.spark.sql.execution.datasources.orc.{OrcFileFormat, OrcOptions, OrcOutputWriter, OrcUtils} import org.apache.spark.sql.execution.datasources.v2.FileWriteBuilder import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class OrcWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String]) + extends FileWriteBuilder(options, paths) { -class OrcWriteBuilder(options: DataSourceOptions) extends FileWriteBuilder(options) { override def prepareWrite( sqlConf: SQLConf, job: Job, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index bedcb9f8d4e12..fdd80ccaf052e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -95,9 +95,8 @@ class MicroBatchExecution( val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" nextSourceId += 1 logInfo(s"Reading table [$table] from DataSourceV2 named '$dsName' [$ds]") - val dsOptions = new DataSourceOptions(options.asJava) // TODO: operator pushdown. 
- val scan = table.newScanBuilder(dsOptions).build() + val scan = table.newScanBuilder(options).build() val stream = scan.toMicroBatchStream(metadataPath) StreamingDataSourceV2Relation(output, scan, stream) }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 180a23c765dd3..cc441937ce70c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -40,10 +40,11 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.execution.datasources.v2.StreamWriterCommitProgress import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite} +import org.apache.spark.sql.sources.v2.SupportsStreamingWrite import org.apache.spark.sql.sources.v2.writer.SupportsTruncate import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{Clock, UninterruptibleThread, Utils} /** States for [[StreamExecution]]'s lifecycle. */ @@ -584,7 +585,7 @@ abstract class StreamExecution( table: SupportsStreamingWrite, options: Map[String, String], inputPlan: LogicalPlan): StreamingWrite = { - val writeBuilder = table.newWriteBuilder(new DataSourceOptions(options.asJava)) + val writeBuilder = table.newWriteBuilder(new CaseInsensitiveStringMap(options.asJava)) .withQueryId(id.toString) .withInputDataSchema(inputPlan.schema) outputMode match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 1b7aa548e6d21..0d7e9ba363d01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Stati import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.sources.v2.{Table, TableProvider} +import org.apache.spark.sql.util.CaseInsensitiveStringMap object StreamingRelation { def apply(dataSource: DataSource): StreamingRelation = { @@ -95,7 +96,7 @@ case class StreamingRelationV2( source: TableProvider, sourceName: String, table: Table, - extraOptions: Map[String, String], + extraOptions: CaseInsensitiveStringMap, output: Seq[Attribute], v1Relation: Option[StreamingRelation])(session: SparkSession) extends LeafNode with MultiInstanceRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 923bd749b29b3..dbdfcf8085604 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -24,6 +24,7 @@ import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.writer.{SupportsTruncate, WriteBuilder} import 
org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ConsoleRelation(override val sqlContext: SQLContext, data: DataFrame) extends BaseRelation { @@ -34,7 +35,7 @@ class ConsoleSinkProvider extends TableProvider with DataSourceRegister with CreatableRelationProvider { - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { ConsoleTable } @@ -62,7 +63,7 @@ object ConsoleTable extends Table with SupportsStreamingWrite { override def schema(): StructType = StructType(Nil) - override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new WriteBuilder with SupportsTruncate { private var inputSchema: StructType = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index f55a45d2cee73..c8fb53df52598 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -22,7 +22,6 @@ import java.util.concurrent.TimeUnit import java.util.concurrent.atomic.AtomicReference import java.util.function.UnaryOperator -import scala.collection.JavaConverters._ import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.SparkEnv @@ -33,7 +32,7 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming.{StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 -import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsContinuousRead, SupportsStreamingWrite} +import org.apache.spark.sql.sources.v2.{SupportsContinuousRead, SupportsStreamingWrite} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.Clock @@ -71,9 +70,8 @@ class ContinuousExecution( val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" nextSourceId += 1 logInfo(s"Reading table [$table] from DataSourceV2 named '$dsName' [$ds]") - val dsOptions = new DataSourceOptions(options.asJava) // TODO: operator pushdown. 
- val scan = table.newScanBuilder(dsOptions).build() + val scan = table.newScanBuilder(options).build() val stream = scan.toContinuousStream(metadataPath) StreamingDataSourceV2Relation(output, scan, stream) }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala index 48ff70f9c9d07..d55f71c7be830 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousRateStreamSource.scala @@ -23,17 +23,13 @@ import org.json4s.jackson.Serialization import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.{RateStreamOffset, ValueRunTimeMsPair} -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming._ case class RateStreamPartitionOffset( partition: Int, currentValue: Long, currentTimeMs: Long) extends PartitionOffset -class RateStreamContinuousStream( - rowsPerSecond: Long, - numPartitions: Int, - options: DataSourceOptions) extends ContinuousStream { +class RateStreamContinuousStream(rowsPerSecond: Long, numPartitions: Int) extends ContinuousStream { implicit val defaultFormats: DefaultFormats = DefaultFormats val creationTime = System.currentTimeMillis() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala index e7bc71394061e..2263b42870a65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousTextSocketSource.scala @@ -34,9 +34,9 @@ import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.streaming.{Offset => _, _} import org.apache.spark.sql.execution.streaming.sources.TextSocketReader -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.RpcUtils @@ -49,7 +49,7 @@ import org.apache.spark.util.RpcUtils * buckets and serves the messages to the executors via a RPC endpoint. 
*/ class TextSocketContinuousStream( - host: String, port: Int, numPartitions: Int, options: DataSourceOptions) + host: String, port: Int, numPartitions: Int, options: CaseInsensitiveStringMap) extends ContinuousStream with Logging { implicit val defaultFormats: DefaultFormats = DefaultFormats diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index e71f81caeb974..df7990c6a652e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream, Offset => OffsetV2} import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap object MemoryStream { protected val currentBlockId = new AtomicInteger(0) @@ -73,7 +74,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas MemoryStreamTableProvider, "memory", new MemoryStreamTable(this), - Map.empty, + CaseInsensitiveStringMap.empty(), attributes, None)(sqlContext.sparkSession) } @@ -84,7 +85,7 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas // This class is used to indicate the memory stream data source. We don't actually use it, as // memory stream is for test only and we never look it up by name. object MemoryStreamTableProvider extends TableProvider { - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { throw new IllegalStateException("MemoryStreamTableProvider should not be used.") } } @@ -96,7 +97,7 @@ class MemoryStreamTable(val stream: MemoryStreamBase[_]) extends Table override def schema(): StructType = stream.fullSchema() - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MemoryStreamScanBuilder(stream) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala index f2ff30bcf1bef..dbe242784986d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ConsoleWrite.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution.streaming.sources import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, SparkSession} import org.apache.spark.sql.catalyst.plans.logical.LocalRelation -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.writer.WriterCommitMessage import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** Common methods used to create writes for the the console sink */ -class ConsoleWrite(schema: StructType, options: DataSourceOptions) +class ConsoleWrite(schema: StructType, options: CaseInsensitiveStringMap) extends StreamingWrite with Logging { // Number of rows to display, by default 20 rows diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala index c0ae44a128ca1..44516bbb2a5a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala @@ -22,10 +22,11 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.python.PythonForeachWriter -import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite, Table} +import org.apache.spark.sql.sources.v2.{SupportsStreamingWrite, Table} import org.apache.spark.sql.sources.v2.writer.{DataWriter, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * A write-only table for forwarding data into the specified [[ForeachWriter]]. @@ -44,7 +45,7 @@ case class ForeachWriterTable[T]( override def schema(): StructType = StructType(Nil) - override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new WriteBuilder with SupportsTruncate { private var inputSchema: StructType = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchStream.scala index a8feed34b96dc..5403eafd54b61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamMicroBatchStream.scala @@ -28,9 +28,9 @@ import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset} +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{ManualClock, SystemClock} class RateStreamMicroBatchStream( @@ -38,7 +38,7 @@ class RateStreamMicroBatchStream( // The default values here are used in tests. 
rampUpTimeSeconds: Long = 0, numPartitions: Int = 1, - options: DataSourceOptions, + options: CaseInsensitiveStringMap, checkpointLocation: String) extends MicroBatchStream with Logging { import RateStreamProvider._ @@ -155,7 +155,7 @@ class RateStreamMicroBatchStream( override def toString: String = s"RateStreamV2[rowsPerSecond=$rowsPerSecond, " + s"rampUpTimeSeconds=$rampUpTimeSeconds, " + - s"numPartitions=${options.get(NUM_PARTITIONS).orElse("default")}" + s"numPartitions=${options.getOrDefault(NUM_PARTITIONS, "default")}" } case class RateStreamMicroBatchInputPartition( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 3a0082536512d..3d8a90e99b85a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * A source that generates increment long values with timestamps. Each generated row has two @@ -43,14 +44,14 @@ import org.apache.spark.sql.types._ class RateStreamProvider extends TableProvider with DataSourceRegister { import RateStreamProvider._ - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { val rowsPerSecond = options.getLong(ROWS_PER_SECOND, 1) if (rowsPerSecond <= 0) { throw new IllegalArgumentException( s"Invalid value '$rowsPerSecond'. 
The option 'rowsPerSecond' must be positive") } - val rampUpTimeSeconds = Option(options.get(RAMP_UP_TIME).orElse(null)) + val rampUpTimeSeconds = Option(options.get(RAMP_UP_TIME)) .map(JavaUtils.timeStringAsSec) .getOrElse(0L) if (rampUpTimeSeconds < 0) { @@ -83,7 +84,7 @@ class RateStreamTable( override def schema(): StructType = RateStreamProvider.SCHEMA - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new ScanBuilder { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder { override def build(): Scan = new Scan { override def readSchema(): StructType = RateStreamProvider.SCHEMA @@ -93,7 +94,7 @@ class RateStreamTable( } override def toContinuousStream(checkpointLocation: String): ContinuousStream = { - new RateStreamContinuousStream(rowsPerSecond, numPartitions, options) + new RateStreamContinuousStream(rowsPerSecond, numPartitions) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala index 540131c8de8a1..9168d46493aef 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala @@ -29,7 +29,6 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming.LongOffset -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.{InputPartition, PartitionReader, PartitionReaderFactory} import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset} import org.apache.spark.unsafe.types.UTF8String @@ -39,8 +38,7 @@ import org.apache.spark.unsafe.types.UTF8String * and debugging. This MicroBatchReadSupport will *not* work in production applications due to * multiple reasons, including no support for fault recovery. 
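Editor's note: the socket provider below validates `host` and `port` directly on the case-insensitive map, so here is a short usage sketch of the options it expects; the host and port values are placeholders, and (as the warning in the code says) this source is for testing only:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").appName("socket-demo").getOrCreate()

// Option keys are looked up case-insensitively by the provider
// ("host", "port", "includeTimestamp", "numPartitions"); 9999 is just a placeholder port.
val lines = spark.readStream
  .format("socket")
  .option("host", "localhost")
  .option("port", 9999)
  .option("includeTimestamp", true)
  .load()

// Echo the stream to the console sink for inspection.
val query = lines.writeStream
  .format("console")
  .start()
```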
*/ -class TextSocketMicroBatchStream( - host: String, port: Int, numPartitions: Int, options: DataSourceOptions) +class TextSocketMicroBatchStream(host: String, port: Int, numPartitions: Int) extends MicroBatchStream with Logging { @GuardedBy("this") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala index 8ac5bfc307aa3..0adbf1d9b3689 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -30,20 +30,21 @@ import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.types.{StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap class TextSocketSourceProvider extends TableProvider with DataSourceRegister with Logging { - private def checkParameters(params: DataSourceOptions): Unit = { + private def checkParameters(params: CaseInsensitiveStringMap): Unit = { logWarning("The socket source should not be used for production applications! " + "It does not support recovery.") - if (!params.get("host").isPresent) { + if (!params.containsKey("host")) { throw new AnalysisException("Set a host to read from with option(\"host\", ...).") } - if (!params.get("port").isPresent) { + if (!params.containsKey("port")) { throw new AnalysisException("Set a port to read from with option(\"port\", ...).") } Try { - params.get("includeTimestamp").orElse("false").toBoolean + params.getBoolean("includeTimestamp", false) } match { case Success(_) => case Failure(_) => @@ -51,10 +52,10 @@ class TextSocketSourceProvider extends TableProvider with DataSourceRegister wit } } - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { checkParameters(options) new TextSocketTable( - options.get("host").get, + options.get("host"), options.getInt("port", -1), options.getInt("numPartitions", SparkSession.active.sparkContext.defaultParallelism), options.getBoolean("includeTimestamp", false)) @@ -77,12 +78,12 @@ class TextSocketTable(host: String, port: Int, numPartitions: Int, includeTimest } } - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new ScanBuilder { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder { override def build(): Scan = new Scan { override def readSchema(): StructType = schema() override def toMicroBatchStream(checkpointLocation: String): MicroBatchStream = { - new TextSocketMicroBatchStream(host, port, numPartitions, options) + new TextSocketMicroBatchStream(host, port, numPartitions) } override def toContinuousStream(checkpointLocation: String): ContinuousStream = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 397c5ff0dcb6a..22adceba930fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -31,10 +31,11 @@ import 
org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} -import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite} +import org.apache.spark.sql.sources.v2.SupportsStreamingWrite import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit @@ -46,7 +47,7 @@ class MemorySinkV2 extends SupportsStreamingWrite with MemorySinkBase with Loggi override def schema(): StructType = StructType(Nil) - override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new WriteBuilder with SupportsTruncate { private var needTruncate: Boolean = false private var inputSchema: StructType = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index 81bffc32027a6..cce8fcd3012bf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRel import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * Interface used to load a streaming `Dataset` from external storage systems (e.g. 
file systems, @@ -175,7 +176,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo val sessionOptions = DataSourceV2Utils.extractSessionConfigs( source = provider, conf = sparkSession.sessionState.conf) val options = sessionOptions ++ extraOptions - val dsOptions = new DataSourceOptions(options.asJava) + val dsOptions = new CaseInsensitiveStringMap(options.asJava) val table = userSpecifiedSchema match { case Some(schema) => provider.getTable(dsOptions, schema) case _ => provider.getTable(dsOptions) @@ -185,7 +186,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo Dataset.ofRows( sparkSession, StreamingRelationV2( - provider, source, table, options, table.schema.toAttributes, v1Relation)( + provider, source, table, dsOptions, table.schema.toAttributes, v1Relation)( sparkSession)) // fallback to v1 diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 984199488fa7b..33d032eb78c2b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -31,7 +31,8 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.{DataSourceOptions, SupportsStreamingWrite, TableProvider} +import org.apache.spark.sql.sources.v2.{SupportsStreamingWrite, TableProvider} +import org.apache.spark.sql.util.CaseInsensitiveStringMap /** * Interface used to write a streaming `Dataset` to external storage systems (e.g. 
file systems, @@ -313,7 +314,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { val sessionOptions = DataSourceV2Utils.extractSessionConfigs( source = provider, conf = df.sparkSession.sessionState.conf) val options = sessionOptions ++ extraOptions - val dsOptions = new DataSourceOptions(options.asJava) + val dsOptions = new CaseInsensitiveStringMap(options.asJava) provider.getTable(dsOptions) match { case s: SupportsStreamingWrite => s case _ => createV1Sink() diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java index 2612b6185fd4c..255a9f887878b 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaAdvancedDataSourceV2.java @@ -24,19 +24,19 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.Filter; import org.apache.spark.sql.sources.GreaterThan; -import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.Table; import org.apache.spark.sql.sources.v2.TableProvider; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; public class JavaAdvancedDataSourceV2 implements TableProvider { @Override - public Table getTable(DataSourceOptions options) { + public Table getTable(CaseInsensitiveStringMap options) { return new JavaSimpleBatchTable() { @Override - public ScanBuilder newScanBuilder(DataSourceOptions options) { + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { return new AdvancedScanBuilder(); } }; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java index d72ab5338aa8c..699859cfaebe1 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaColumnarDataSourceV2.java @@ -21,11 +21,11 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; -import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.Table; import org.apache.spark.sql.sources.v2.TableProvider; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarBatch; @@ -49,10 +49,10 @@ public PartitionReaderFactory createReaderFactory() { } @Override - public Table getTable(DataSourceOptions options) { + public Table getTable(CaseInsensitiveStringMap options) { return new JavaSimpleBatchTable() { @Override - public ScanBuilder newScanBuilder(DataSourceOptions options) { + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { return new MyScanBuilder(); } }; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index a513bfb26ef1c..dfbea927e477b 100644 --- 
a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -22,13 +22,13 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.Table; import org.apache.spark.sql.sources.v2.TableProvider; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.sources.v2.reader.partitioning.ClusteredDistribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Distribution; import org.apache.spark.sql.sources.v2.reader.partitioning.Partitioning; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; public class JavaPartitionAwareDataSource implements TableProvider { @@ -54,10 +54,10 @@ public Partitioning outputPartitioning() { } @Override - public Table getTable(DataSourceOptions options) { + public Table getTable(CaseInsensitiveStringMap options) { return new JavaSimpleBatchTable() { @Override - public ScanBuilder newScanBuilder(DataSourceOptions options) { + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { return new MyScanBuilder(); } }; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaReportStatisticsDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaReportStatisticsDataSource.java index bbc8492ec4e16..f3755e18b58d5 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaReportStatisticsDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaReportStatisticsDataSource.java @@ -19,13 +19,13 @@ import java.util.OptionalLong; -import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.Table; import org.apache.spark.sql.sources.v2.TableProvider; import org.apache.spark.sql.sources.v2.reader.InputPartition; import org.apache.spark.sql.sources.v2.reader.ScanBuilder; import org.apache.spark.sql.sources.v2.reader.Statistics; import org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; public class JavaReportStatisticsDataSource implements TableProvider { class MyScanBuilder extends JavaSimpleScanBuilder implements SupportsReportStatistics { @@ -54,10 +54,10 @@ public InputPartition[] planInputPartitions() { } @Override - public Table getTable(DataSourceOptions options) { + public Table getTable(CaseInsensitiveStringMap options) { return new JavaSimpleBatchTable() { @Override - public ScanBuilder newScanBuilder(DataSourceOptions options) { + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { return new MyScanBuilder(); } }; diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java index 815d57ba94139..3800a94f88898 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSchemaRequiredDataSource.java @@ -17,11 +17,11 @@ package test.org.apache.spark.sql.sources.v2; -import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.Table; import org.apache.spark.sql.sources.v2.TableProvider; import 
org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; public class JavaSchemaRequiredDataSource implements TableProvider { @@ -45,7 +45,7 @@ public InputPartition[] planInputPartitions() { } @Override - public Table getTable(DataSourceOptions options, StructType schema) { + public Table getTable(CaseInsensitiveStringMap options, StructType schema) { return new JavaSimpleBatchTable() { @Override @@ -54,14 +54,14 @@ public StructType schema() { } @Override - public ScanBuilder newScanBuilder(DataSourceOptions options) { + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { return new MyScanBuilder(schema); } }; } @Override - public Table getTable(DataSourceOptions options) { + public Table getTable(CaseInsensitiveStringMap options) { throw new IllegalArgumentException("requires a user-supplied schema"); } } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java index 852c4546df885..7474f36c97f75 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleDataSourceV2.java @@ -17,10 +17,10 @@ package test.org.apache.spark.sql.sources.v2; -import org.apache.spark.sql.sources.v2.DataSourceOptions; import org.apache.spark.sql.sources.v2.Table; import org.apache.spark.sql.sources.v2.TableProvider; import org.apache.spark.sql.sources.v2.reader.*; +import org.apache.spark.sql.util.CaseInsensitiveStringMap; public class JavaSimpleDataSourceV2 implements TableProvider { @@ -36,10 +36,10 @@ public InputPartition[] planInputPartitions() { } @Override - public Table getTable(DataSourceOptions options) { + public Table getTable(CaseInsensitiveStringMap options) { return new JavaSimpleBatchTable() { @Override - public ScanBuilder newScanBuilder(DataSourceOptions options) { + public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { return new MyScanBuilder(); } }; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala index cccd8e9ee8bd1..034454d21d7ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcFilterSuite.scala @@ -32,7 +32,6 @@ import org.apache.spark.sql.execution.datasources.{DataSourceStrategy, HadoopFsR import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.execution.datasources.v2.orc.OrcTable import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -58,7 +57,7 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext { case PhysicalOperation(_, filters, DataSourceV2Relation(orcTable: OrcTable, _, options)) => assert(filters.nonEmpty, "No filter is analyzed from the given query") - val scanBuilder = orcTable.newScanBuilder(new DataSourceOptions(options.asJava)) + val scanBuilder = orcTable.newScanBuilder(options) scanBuilder.pushFilters(filters.flatMap(DataSourceStrategy.translateFilter).toArray) val pushedFilters = scanBuilder.pushedFilters() 
assert(pushedFilters.nonEmpty, "No filter is pushed down") @@ -102,7 +101,7 @@ class OrcFilterSuite extends OrcTest with SharedSQLContext { case PhysicalOperation(_, filters, DataSourceV2Relation(orcTable: OrcTable, _, options)) => assert(filters.nonEmpty, "No filter is analyzed from the given query") - val scanBuilder = orcTable.newScanBuilder(new DataSourceOptions(options.asJava)) + val scanBuilder = orcTable.newScanBuilder(options) scanBuilder.pushFilters(filters.flatMap(DataSourceStrategy.translateFilter).toArray) val pushedFilters = scanBuilder.pushedFilters() if (noneSupported) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index d0418f893143e..c04f6e3f255cb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -29,9 +29,9 @@ import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relati import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.StreamTest +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ManualClock class RateStreamProviderSuite extends StreamTest { @@ -135,7 +135,7 @@ class RateStreamProviderSuite extends StreamTest { withTempDir { temp => val stream = new RateStreamMicroBatchStream( rowsPerSecond = 100, - options = new DataSourceOptions(Map("useManualClock" -> "true").asJava), + options = new CaseInsensitiveStringMap(Map("useManualClock" -> "true").asJava), checkpointLocation = temp.getCanonicalPath) stream.clock.asInstanceOf[ManualClock].advance(100000) val startOffset = stream.initialOffset() @@ -154,7 +154,7 @@ class RateStreamProviderSuite extends StreamTest { withTempDir { temp => val stream = new RateStreamMicroBatchStream( rowsPerSecond = 20, - options = DataSourceOptions.empty(), + options = CaseInsensitiveStringMap.empty(), checkpointLocation = temp.getCanonicalPath) val partitions = stream.planInputPartitions(LongOffset(0L), LongOffset(1L)) val readerFactory = stream.createReaderFactory() @@ -173,7 +173,7 @@ class RateStreamProviderSuite extends StreamTest { val stream = new RateStreamMicroBatchStream( rowsPerSecond = 33, numPartitions = 11, - options = DataSourceOptions.empty(), + options = CaseInsensitiveStringMap.empty(), checkpointLocation = temp.getCanonicalPath) val partitions = stream.planInputPartitions(LongOffset(0L), LongOffset(1L)) val readerFactory = stream.createReaderFactory() @@ -309,8 +309,7 @@ class RateStreamProviderSuite extends StreamTest { } test("continuous data") { - val stream = new RateStreamContinuousStream( - rowsPerSecond = 20, numPartitions = 2, options = DataSourceOptions.empty()) + val stream = new RateStreamContinuousStream(rowsPerSecond = 20, numPartitions = 2) val partitions = stream.planInputPartitions(stream.initialOffset) val readerFactory = stream.createContinuousReaderFactory() assert(partitions.size == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 33c65d784fba6..6a7c54176c347 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -35,11 +35,11 @@ import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relati import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.DataSourceOptions import org.apache.spark.sql.sources.v2.reader.streaming.Offset import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.CaseInsensitiveStringMap class TextSocketStreamSuite extends StreamTest with SharedSQLContext with BeforeAndAfterEach { @@ -176,13 +176,13 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before test("params not given") { val provider = new TextSocketSourceProvider intercept[AnalysisException] { - provider.getTable(new DataSourceOptions(Map.empty[String, String].asJava)) + provider.getTable(CaseInsensitiveStringMap.empty()) } intercept[AnalysisException] { - provider.getTable(new DataSourceOptions(Map("host" -> "localhost").asJava)) + provider.getTable(new CaseInsensitiveStringMap(Map("host" -> "localhost").asJava)) } intercept[AnalysisException] { - provider.getTable(new DataSourceOptions(Map("port" -> "1234").asJava)) + provider.getTable(new CaseInsensitiveStringMap(Map("port" -> "1234").asJava)) } } @@ -190,7 +190,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before val provider = new TextSocketSourceProvider val params = Map("host" -> "localhost", "port" -> "1234", "includeTimestamp" -> "fasle") intercept[AnalysisException] { - provider.getTable(new DataSourceOptions(params.asJava)) + provider.getTable(new CaseInsensitiveStringMap(params.asJava)) } } @@ -201,7 +201,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before StructField("area", StringType) :: Nil) val params = Map("host" -> "localhost", "port" -> "1234") val exception = intercept[UnsupportedOperationException] { - provider.getTable(new DataSourceOptions(params.asJava), userSpecifiedSchema) + provider.getTable(new CaseInsensitiveStringMap(params.asJava), userSpecifiedSchema) } assert(exception.getMessage.contains( "socket source does not support user-specified schema")) @@ -299,7 +299,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before host = "localhost", port = serverThread.port, numPartitions = 2, - options = DataSourceOptions.empty()) + options = CaseInsensitiveStringMap.empty()) val partitions = stream.planInputPartitions(stream.initialOffset()) assert(partitions.length == 2) @@ -351,7 +351,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before host = "localhost", port = serverThread.port, numPartitions = 2, - options = DataSourceOptions.empty()) + options = CaseInsensitiveStringMap.empty()) stream.startOffset = TextSocketOffset(List(5, 5)) assertThrows[IllegalStateException] { @@ -367,7 +367,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before host = "localhost", port = serverThread.port, numPartitions = 2, - options = 
new DataSourceOptions(Map("includeTimestamp" -> "true").asJava)) + options = new CaseInsensitiveStringMap(Map("includeTimestamp" -> "true").asJava)) val partitions = stream.planInputPartitions(stream.initialOffset()) assert(partitions.size == 2) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index b8572448f736e..705559d099bec 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -34,6 +34,7 @@ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.sql.vectorized.ColumnarBatch class DataSourceV2Suite extends QueryTest with SharedSQLContext { @@ -349,7 +350,7 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { val options = df.queryExecution.optimizedPlan.collectFirst { case d: DataSourceV2Relation => d.options }.get - assert(options.get(optionName).get == "false") + assert(options.get(optionName) === "false") } } @@ -437,8 +438,8 @@ class SimpleSinglePartitionSource extends TableProvider { } } - override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable { - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } } @@ -454,8 +455,8 @@ class SimpleDataSourceV2 extends TableProvider { } } - override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable { - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } } @@ -463,8 +464,8 @@ class SimpleDataSourceV2 extends TableProvider { class AdvancedDataSourceV2 extends TableProvider { - override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable { - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new AdvancedScanBuilder() } } @@ -559,16 +560,16 @@ class SchemaRequiredDataSource extends TableProvider { override def readSchema(): StructType = schema } - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { throw new IllegalArgumentException("requires a user-supplied schema") } - override def getTable(options: DataSourceOptions, schema: StructType): Table = { + override def getTable(options: CaseInsensitiveStringMap, schema: StructType): Table = { val userGivenSchema = schema new SimpleBatchTable { override def schema(): StructType = userGivenSchema - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder(userGivenSchema) } } @@ 
-588,8 +589,8 @@ class ColumnarDataSourceV2 extends TableProvider { } } - override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable { - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } } @@ -659,8 +660,8 @@ class PartitionAwareDataSource extends TableProvider { override def outputPartitioning(): Partitioning = new MyPartitioning } - override def getTable(options: DataSourceOptions): Table = new SimpleBatchTable { - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def getTable(options: CaseInsensitiveStringMap): Table = new SimpleBatchTable { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder() } } @@ -699,7 +700,7 @@ class SchemaReadAttemptException(m: String) extends RuntimeException(m) class SimpleWriteOnlyDataSource extends SimpleWritableDataSource { - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new MyTable(options) { override def schema(): StructType = { throw new SchemaReadAttemptException("schema should not be read.") @@ -725,9 +726,9 @@ class ReportStatisticsDataSource extends TableProvider { } } - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new SimpleBatchTable { - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala index fd19a48497fe6..f9f9db35ac2dd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala @@ -18,13 +18,14 @@ package org.apache.spark.sql.sources.v2 import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.execution.datasources.FileFormat -import org.apache.spark.sql.execution.datasources.parquet.{ParquetFileFormat, ParquetTest} +import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat import org.apache.spark.sql.execution.datasources.v2.FileDataSourceV2 import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.reader.ScanBuilder import org.apache.spark.sql.sources.v2.writer.WriteBuilder import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap class DummyReadOnlyFileDataSourceV2 extends FileDataSourceV2 { @@ -32,7 +33,7 @@ class DummyReadOnlyFileDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "parquet" - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new DummyReadOnlyFileTable } } @@ -42,7 +43,7 @@ class DummyReadOnlyFileTable extends Table with SupportsBatchRead { override def schema(): StructType = StructType(Nil) - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def newScanBuilder(options: 
CaseInsensitiveStringMap): ScanBuilder = { throw new AnalysisException("Dummy file reader") } } @@ -53,7 +54,7 @@ class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { override def shortName(): String = "parquet" - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new DummyWriteOnlyFileTable } } @@ -63,7 +64,7 @@ class DummyWriteOnlyFileTable extends Table with SupportsBatchWrite { override def schema(): StructType = StructType(Nil) - override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = throw new AnalysisException("Dummy file writer") } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index c56a54598cd4c..160354520e432 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -25,12 +25,12 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkContext -import org.apache.spark.internal.config.SPECULATION_ENABLED import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration /** @@ -141,22 +141,24 @@ class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { } } - class MyTable(options: DataSourceOptions) extends SimpleBatchTable with SupportsBatchWrite { - private val path = options.get("path").get() + class MyTable(options: CaseInsensitiveStringMap) + extends SimpleBatchTable with SupportsBatchWrite { + + private val path = options.get("path") private val conf = SparkContext.getActive.get.hadoopConfiguration override def schema(): StructType = tableSchema - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = { + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MyScanBuilder(new Path(path).toUri.toString, conf) } - override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new MyWriteBuilder(path) } } - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new MyTable(options) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 553b48398c9ba..13bb686fbd3b9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.sources.v2.writer.{WriteBuilder, WriterCommitMessage import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery, StreamTest, Trigger} import 
org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.Utils class FakeDataStream extends MicroBatchStream with ContinuousStream { @@ -76,19 +77,19 @@ class FakeWriteBuilder extends WriteBuilder with StreamingWrite { trait FakeMicroBatchReadTable extends Table with SupportsMicroBatchRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new FakeScanBuilder + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new FakeScanBuilder } trait FakeContinuousReadTable extends Table with SupportsContinuousRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) - override def newScanBuilder(options: DataSourceOptions): ScanBuilder = new FakeScanBuilder + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new FakeScanBuilder } trait FakeStreamingWriteTable extends Table with SupportsStreamingWrite { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) - override def newWriteBuilder(options: DataSourceOptions): WriteBuilder = { + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new FakeWriteBuilder } } @@ -101,7 +102,7 @@ class FakeReadMicroBatchOnly override def keyPrefix: String = shortName() - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { LastReadOptions.options = options new FakeMicroBatchReadTable {} } @@ -115,7 +116,7 @@ class FakeReadContinuousOnly override def keyPrefix: String = shortName() - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { LastReadOptions.options = options new FakeContinuousReadTable {} } @@ -124,7 +125,7 @@ class FakeReadContinuousOnly class FakeReadBothModes extends DataSourceRegister with TableProvider { override def shortName(): String = "fake-read-microbatch-continuous" - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new Table with FakeMicroBatchReadTable with FakeContinuousReadTable {} } } @@ -132,7 +133,7 @@ class FakeReadBothModes extends DataSourceRegister with TableProvider { class FakeReadNeitherMode extends DataSourceRegister with TableProvider { override def shortName(): String = "fake-read-neither-mode" - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new Table { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -148,7 +149,7 @@ class FakeWriteOnly override def keyPrefix: String = shortName() - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { LastWriteOptions.options = options new Table with FakeStreamingWriteTable { override def name(): String = "fake" @@ -159,7 +160,7 @@ class FakeWriteOnly class FakeNoWrite extends DataSourceRegister with TableProvider { override def shortName(): String = "fake-write-neither-mode" - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new Table { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -186,7 +187,7 
@@ class FakeWriteSupportProviderV1Fallback extends DataSourceRegister override def shortName(): String = "fake-write-v1-fallback" - override def getTable(options: DataSourceOptions): Table = { + override def getTable(options: CaseInsensitiveStringMap): Table = { new Table with FakeStreamingWriteTable { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) @@ -195,7 +196,7 @@ class FakeWriteSupportProviderV1Fallback extends DataSourceRegister } object LastReadOptions { - var options: DataSourceOptions = _ + var options: CaseInsensitiveStringMap = _ def clear(): Unit = { options = null @@ -203,7 +204,7 @@ object LastReadOptions { } object LastWriteOptions { - var options: DataSourceOptions = _ + var options: CaseInsensitiveStringMap = _ def clear(): Unit = { options = null @@ -320,8 +321,8 @@ class StreamingDataSourceV2Suite extends StreamTest { testPositiveCaseWithQuery(readSource, writeSource, trigger) { _ => eventually(timeout(streamingTimeout)) { // Write options should not be set. - assert(LastWriteOptions.options.getBoolean(readOptionName, false) == false) - assert(LastReadOptions.options.getBoolean(readOptionName, false) == true) + assert(!LastWriteOptions.options.containsKey(readOptionName)) + assert(LastReadOptions.options.getBoolean(readOptionName, false)) } } } @@ -331,8 +332,8 @@ class StreamingDataSourceV2Suite extends StreamTest { testPositiveCaseWithQuery(readSource, writeSource, trigger) { _ => eventually(timeout(streamingTimeout)) { // Read options should not be set. - assert(LastReadOptions.options.getBoolean(writeOptionName, false) == false) - assert(LastWriteOptions.options.getBoolean(writeOptionName, false) == true) + assert(!LastReadOptions.options.containsKey(writeOptionName)) + assert(LastWriteOptions.options.getBoolean(writeOptionName, false)) } } } @@ -351,10 +352,10 @@ class StreamingDataSourceV2Suite extends StreamTest { for ((read, write, trigger) <- cases) { testQuietly(s"stream with read format $read, write format $write, trigger $trigger") { val sourceTable = DataSource.lookupDataSource(read, spark.sqlContext.conf).getConstructor() - .newInstance().asInstanceOf[TableProvider].getTable(DataSourceOptions.empty()) + .newInstance().asInstanceOf[TableProvider].getTable(CaseInsensitiveStringMap.empty()) val sinkTable = DataSource.lookupDataSource(write, spark.sqlContext.conf).getConstructor() - .newInstance().asInstanceOf[TableProvider].getTable(DataSourceOptions.empty()) + .newInstance().asInstanceOf[TableProvider].getTable(CaseInsensitiveStringMap.empty()) (sourceTable, sinkTable, trigger) match { // Valid microbatch queries. From 8f9c5acfb5502fc97b16a3861b9cf8dcd45672d8 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Mon, 18 Mar 2019 18:25:11 +0800 Subject: [PATCH 18/70] [SPARK-26811][SQL] Add capabilities to v2.Table This adds a new method, `capabilities` to `v2.Table` that returns a set of `TableCapability`. Capabilities are used to fail queries during analysis checks, `V2WriteSupportCheck`, when the table does not support operations, like truncation. Existing tests for regressions, added new analysis suite, `V2WriteSupportCheckSuite`, for new capability checks. Closes #24012 from rdblue/SPARK-26811-add-capabilities. 
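To make the new contract concrete, here is a minimal, hypothetical sketch (not part of this
patch) of a v2 table that advertises `BATCH_READ` through the new `capabilities()` method.
The class name and trivial scan are illustrative only; the interfaces and signatures follow
the sources changed below.

import java.util

import scala.collection.JavaConverters._

import org.apache.spark.sql.sources.v2.{SupportsRead, Table, TableCapability}
import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder}
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

// Hypothetical table that declares BATCH_READ instead of mixing in the removed
// SupportsBatchRead marker interface.
class ExampleTable extends Table with SupportsRead {
  override def name(): String = "example"

  override def schema(): StructType = new StructType().add("id", "long")

  // Analysis-time checks consult this set; for example, DataFrameReader only plans a
  // v2 batch read when BATCH_READ is present.
  override def capabilities(): util.Set[TableCapability] =
    Set(TableCapability.BATCH_READ).asJava

  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder {
    override def build(): Scan = new Scan {
      override def readSchema(): StructType = schema()
      // A real source would also override toBatch() here, since BATCH_READ is declared.
    }
  }
}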
Authored-by: Ryan Blue Signed-off-by: Wenchen Fan --- .../sql/kafka010/KafkaSourceProvider.scala | 4 +- .../sql/sources/v2/SupportsBatchRead.java | 34 ---- .../sql/sources/v2/SupportsBatchWrite.java | 33 ---- .../spark/sql/sources/v2/SupportsRead.java | 2 +- .../spark/sql/sources/v2/SupportsWrite.java | 2 +- .../apache/spark/sql/sources/v2/Table.java | 14 +- .../spark/sql/sources/v2/TableCapability.java | 69 ++++++++ .../spark/sql/sources/v2/reader/Scan.java | 7 +- .../sql/sources/v2/writer/WriteBuilder.java | 4 +- .../apache/spark/sql/DataFrameReader.scala | 6 +- .../apache/spark/sql/DataFrameWriter.scala | 7 +- .../datasources/noop/NoopDataSource.scala | 7 +- .../v2/DataSourceV2Implicits.scala | 18 ++- .../datasources/v2/DataSourceV2Relation.scala | 2 +- .../datasources/v2/DataSourceV2Strategy.scala | 6 +- .../execution/datasources/v2/FileTable.scala | 11 +- .../datasources/v2/V2WriteSupportCheck.scala | 56 +++++++ .../v2/WriteToDataSourceV2Exec.scala | 12 +- .../sql/execution/streaming/console.scala | 5 + .../sql/execution/streaming/memory.scala | 4 + .../sources/ForeachWriterTable.scala | 7 +- .../sources/RateStreamProvider.scala | 5 + .../sources/TextSocketSourceProvider.scala | 5 +- .../streaming/sources/memoryV2.scala | 6 +- .../internal/BaseSessionStateBuilder.scala | 2 + .../sql/sources/v2/JavaSimpleBatchTable.java | 17 +- .../sql/sources/v2/DataSourceV2Suite.scala | 8 +- .../v2/FileDataSourceV2FallBackSuite.scala | 12 +- .../sources/v2/SimpleWritableDataSource.scala | 7 +- .../sources/v2/V2WriteSupportCheckSuite.scala | 149 ++++++++++++++++++ .../sources/StreamingDataSourceV2Suite.scala | 8 + .../sql/hive/HiveSessionStateBuilder.scala | 2 + 32 files changed, 417 insertions(+), 114 deletions(-) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java create mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V2WriteSupportCheckSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 12f09afdb238d..4af263e1a7f2e 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Locale, UUID} +import java.util.{Collections, Locale, UUID} import scala.collection.JavaConverters._ @@ -358,6 +358,8 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister override def schema(): StructType = KafkaOffsetReader.kafkaSchema + override def capabilities(): ju.Set[TableCapability] = Collections.emptySet() + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder { override def build(): Scan = new KafkaScan(options) } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java deleted file mode 100644 index ea7c5d2b108f0..0000000000000 --- 
a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchRead.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.reader.Scan; -import org.apache.spark.sql.sources.v2.reader.ScanBuilder; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; - -/** - * An empty mix-in interface for {@link Table}, to indicate this table supports batch scan. - *
- * If a {@link Table} implements this interface, the - * {@link SupportsRead#newScanBuilder(CaseInsensitiveStringMap)} must return a {@link ScanBuilder} - * that builds {@link Scan} with {@link Scan#toBatch()} implemented. - *
- */ -@Evolving -public interface SupportsBatchRead extends SupportsRead { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java deleted file mode 100644 index 09e23f84fd6bf..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsBatchWrite.java +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.writer.WriteBuilder; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; - -/** - * An empty mix-in interface for {@link Table}, to indicate this table supports batch write. - *
- * If a {@link Table} implements this interface, the - * {@link SupportsWrite#newWriteBuilder(CaseInsensitiveStringMap)} must return a - * {@link WriteBuilder} with {@link WriteBuilder#buildForBatch()} implemented. - *
- */ -@Evolving -public interface SupportsBatchWrite extends SupportsWrite {} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java index 14990effeda37..67fc72e070dc9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java @@ -26,7 +26,7 @@ * {@link #newScanBuilder(CaseInsensitiveStringMap)} that is used to create a scan for batch, * micro-batch, or continuous processing. */ -interface SupportsRead extends Table { +public interface SupportsRead extends Table { /** * Returns a {@link ScanBuilder} which can be used to build a {@link Scan}. Spark will call this diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java index f0d8e44f15287..b215963868217 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java @@ -26,7 +26,7 @@ * {@link #newWriteBuilder(CaseInsensitiveStringMap)} that is used to create a write * for batch or streaming. */ -interface SupportsWrite extends Table { +public interface SupportsWrite extends Table { /** * Returns a {@link WriteBuilder} which can be used to create {@link BatchWrite}. Spark will call diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java index 08664859b8de2..78f979a2a9a44 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java @@ -20,16 +20,15 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.types.StructType; +import java.util.Set; + /** * An interface representing a logical structured data set of a data source. For example, the * implementation can be a directory on the file system, a topic of Kafka, or a table in the * catalog, etc. *
- * This interface can mixin the following interfaces to support different operations:
- * <p>
- * <ul>
- *   <li>{@link SupportsBatchRead}: this table can be read in batch queries.</li>
- * </ul>
+ * This interface can mixin the following interfaces to support different operations, like + * {@code SupportsRead}. */ @Evolving public interface Table { @@ -45,4 +44,9 @@ public interface Table { * empty schema can be returned here. */ StructType schema(); + + /** + * Returns the set of capabilities for this table. + */ + Set capabilities(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java new file mode 100644 index 0000000000000..8d3fdcd694e2c --- /dev/null +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java @@ -0,0 +1,69 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2; + +import org.apache.spark.annotation.Experimental; + +/** + * Capabilities that can be provided by a {@link Table} implementation. + *
+ * Tables use {@link Table#capabilities()} to return a set of capabilities. Each capability signals + * to Spark that the table supports a feature identified by the capability. For example, returning + * {@code BATCH_READ} allows Spark to read from the table using a batch scan. + */ +@Experimental +public enum TableCapability { + /** + * Signals that the table supports reads in batch execution mode. + */ + BATCH_READ, + + /** + * Signals that the table supports append writes in batch execution mode. + *
+ * Tables that return this capability must support appending data and may also support additional + * write modes, like {@link #TRUNCATE}, {@link #OVERWRITE_BY_FILTER}, and + * {@link #OVERWRITE_DYNAMIC}. + */ + BATCH_WRITE, + + /** + * Signals that the table can be truncated in a write operation. + *
+ * Truncating a table removes all existing rows. + *
+ * See {@link org.apache.spark.sql.sources.v2.writer.SupportsTruncate}. + */ + TRUNCATE, + + /** + * Signals that the table can replace existing data that matches a filter with appended data in + * a write operation. + *
+ * See {@link org.apache.spark.sql.sources.v2.writer.SupportsOverwrite}. + */ + OVERWRITE_BY_FILTER, + + /** + * Signals that the table can dynamically replace existing data partitions with appended data in + * a write operation. + *
+ * See {@link org.apache.spark.sql.sources.v2.writer.SupportsDynamicOverwrite}. + */ + OVERWRITE_DYNAMIC +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java index 25ab06eee42e0..e97d0548c66ff 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java @@ -21,7 +21,6 @@ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousStream; import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchStream; import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.sources.v2.SupportsBatchRead; import org.apache.spark.sql.sources.v2.SupportsContinuousRead; import org.apache.spark.sql.sources.v2.SupportsMicroBatchRead; import org.apache.spark.sql.sources.v2.Table; @@ -33,8 +32,8 @@ * This logical representation is shared between batch scan, micro-batch streaming scan and * continuous streaming scan. Data sources must implement the corresponding methods in this * interface, to match what the table promises to support. For example, {@link #toBatch()} must be - * implemented, if the {@link Table} that creates this {@link Scan} implements - * {@link SupportsBatchRead}. + * implemented, if the {@link Table} that creates this {@link Scan} returns BATCH_READ support in + * its {@link Table#capabilities()}. *
*/ @Evolving @@ -62,7 +61,7 @@ default String description() { /** * Returns the physical representation of this scan for batch query. By default this method throws * exception, data sources must overwrite this method to provide an implementation, if the - * {@link Table} that creates this scan implements {@link SupportsBatchRead}. + * {@link Table} that creates this returns batch read support in its {@link Table#capabilities()}. * * @throws UnsupportedOperationException */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java index 07529fe1dee91..e08d34fbf453e 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.sources.v2.writer; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.SupportsBatchWrite; import org.apache.spark.sql.sources.v2.Table; import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite; import org.apache.spark.sql.types.StructType; @@ -58,7 +57,8 @@ default WriteBuilder withInputDataSchema(StructType schema) { /** * Returns a {@link BatchWrite} to write data to batch source. By default this method throws * exception, data sources must overwrite this method to provide an implementation, if the - * {@link Table} that creates this scan implements {@link SupportsBatchWrite}. + * {@link Table} that creates this write returns BATCH_WRITE support in its + * {@link Table#capabilities()}. * * Note that, the returned {@link BatchWrite} can be null if the implementation supports SaveMode, * to indicate that no writing is needed. 
We can clean it up after removing diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index 2235217b9c1ec..b6d347e6415a3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -37,8 +37,9 @@ import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.datasources.csv._ import org.apache.spark.sql.execution.datasources.jdbc._ import org.apache.spark.sql.execution.datasources.json.TextInputJsonDataSource -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, FileTable} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2} import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.unsafe.types.UTF8String @@ -221,8 +222,9 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { case Some(schema) => provider.getTable(dsOptions, schema) case _ => provider.getTable(dsOptions) } + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { - case _: SupportsBatchRead => + case _: SupportsRead if table.supports(BATCH_READ) => Dataset.ofRows(sparkSession, DataSourceV2Relation.create(table, dsOptions)) case _ => loadV1Source(paths: _*) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 9f766cfccdf93..289efde2a7f00 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, Logi import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, WriteToDataSourceV2} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -265,8 +266,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val checkFilesExistsOption = "check_files_exist" -> "false" val options = sessionOptions ++ extraOptions + checkFilesExistsOption val dsOptions = new CaseInsensitiveStringMap(options.asJava) + + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ provider.getTable(dsOptions) match { - case table: SupportsBatchWrite => + case table: SupportsWrite if table.supports(BATCH_WRITE) => lazy val relation = DataSourceV2Relation.create(table, dsOptions) mode match { case SaveMode.Append => @@ -274,7 +277,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { AppendData.byName(relation, df.logicalPlan) } - case SaveMode.Overwrite => + case SaveMode.Overwrite if table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER) => // truncate the table runCommand(df.sparkSession, "save") { OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true)) diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index aa2a5e9a06fbd..96a78d3a0da20 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -17,6 +17,10 @@ package org.apache.spark.sql.execution.datasources.noop +import java.util + +import scala.collection.JavaConverters._ + import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.DataSourceRegister @@ -35,10 +39,11 @@ class NoopDataSource extends TableProvider with DataSourceRegister { override def getTable(options: CaseInsensitiveStringMap): Table = NoopTable } -private[noop] object NoopTable extends Table with SupportsBatchWrite with SupportsStreamingWrite { +private[noop] object NoopTable extends Table with SupportsWrite with SupportsStreamingWrite { override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = NoopWriteBuilder override def name(): String = "noop-table" override def schema(): StructType = new StructType() + override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_WRITE).asJava } private[noop] object NoopWriteBuilder extends WriteBuilder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala index 2081af35ce2d1..eed69cdc8cac6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Implicits.scala @@ -18,26 +18,30 @@ package org.apache.spark.sql.execution.datasources.v2 import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table} +import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, Table, TableCapability} object DataSourceV2Implicits { implicit class TableHelper(table: Table) { - def asBatchReadable: SupportsBatchRead = { + def asReadable: SupportsRead = { table match { - case support: SupportsBatchRead => + case support: SupportsRead => support case _ => - throw new AnalysisException(s"Table does not support batch reads: ${table.name}") + throw new AnalysisException(s"Table does not support reads: ${table.name}") } } - def asBatchWritable: SupportsBatchWrite = { + def asWritable: SupportsWrite = { table match { - case support: SupportsBatchWrite => + case support: SupportsWrite => support case _ => - throw new AnalysisException(s"Table does not support batch writes: ${table.name}") + throw new AnalysisException(s"Table does not support writes: ${table.name}") } } + + def supports(capability: TableCapability): Boolean = table.capabilities.contains(capability) + + def supportsAny(capabilities: TableCapability*): Boolean = capabilities.exists(supports) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 17407827d0564..411995718603c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -49,7 +49,7 @@ case class DataSourceV2Relation( } def newScanBuilder(): ScanBuilder = { - table.asBatchReadable.newScanBuilder(options) + table.asReadable.newScanBuilder(options) } override def computeStats(): Statistics = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index b3a65eeac4dbc..b4b21e1b6d69e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -147,7 +147,7 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil case AppendData(r: DataSourceV2Relation, query, _) => - AppendDataExec(r.table.asBatchWritable, r.options, planLater(query)) :: Nil + AppendDataExec(r.table.asWritable, r.options, planLater(query)) :: Nil case OverwriteByExpression(r: DataSourceV2Relation, deleteExpr, query, _) => // fail if any filter cannot be converted. correctness depends on removing all matching data. @@ -157,10 +157,10 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { }.toArray OverwriteByExpressionExec( - r.table.asBatchWritable, filters, r.options, planLater(query)) :: Nil + r.table.asWritable, filters, r.options, planLater(query)) :: Nil case OverwritePartitionsDynamic(r: DataSourceV2Relation, query, _) => - OverwritePartitionsDynamicExec(r.table.asBatchWritable, r.options, planLater(query)) :: Nil + OverwritePartitionsDynamicExec(r.table.asWritable, r.options, planLater(query)) :: Nil case WriteToContinuousDataSource(writer, query) => WriteToContinuousDataSourceExec(writer, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index 08873a3b5a643..c00e65b07312f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -22,7 +22,8 @@ import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.sources.v2.{SupportsBatchRead, SupportsBatchWrite, Table} +import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, Table, TableCapability} +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -31,7 +32,7 @@ abstract class FileTable( options: CaseInsensitiveStringMap, paths: Seq[String], userSpecifiedSchema: Option[StructType]) - extends Table with SupportsBatchRead with SupportsBatchWrite { + extends Table with SupportsRead with SupportsWrite { lazy val fileIndex: PartitioningAwareFileIndex = { val scalaMap = options.asScala.toMap @@ -58,6 +59,8 @@ abstract class FileTable( fileIndex.partitionSchema, caseSensitive)._1 } + override def capabilities(): java.util.Set[TableCapability] = FileTable.CAPABILITIES + /** * When possible, this method should return the schema of the given `files`. 
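The guard pattern used by the `DataFrameWriter` hunk above, and enforced again by the `V2WriteSupportCheck` rule introduced just below, can be summarized with the new `DataSourceV2Implicits` helpers. A small sketch, assuming only the helpers shown in this patch (`supports`, `supportsAny`); the object and method names here are invented:

```scala
import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._
import org.apache.spark.sql.sources.v2.{SupportsWrite, Table}
import org.apache.spark.sql.sources.v2.TableCapability._

object CapabilityGuards {
  // SaveMode.Overwrite is only planned when the table can either truncate
  // or overwrite by filter, in addition to supporting batch writes at all.
  def canPlanOverwrite(table: Table): Boolean =
    table.isInstanceOf[SupportsWrite] &&
      table.supports(BATCH_WRITE) &&
      table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER)

  // Dynamic partition overwrite additionally requires OVERWRITE_DYNAMIC.
  def canPlanDynamicOverwrite(table: Table): Boolean =
    table.isInstanceOf[SupportsWrite] &&
      table.supports(BATCH_WRITE) &&
      table.supports(OVERWRITE_DYNAMIC)
}
```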
When the format * does not support inference, or no valid files are given should return None. In these cases @@ -65,3 +68,7 @@ abstract class FileTable( */ def inferSchema(files: Seq[FileStatus]): Option[StructType] } + +object FileTable { + private val CAPABILITIES = Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala new file mode 100644 index 0000000000000..cf77998c122f8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2WriteSupportCheck.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.expressions.Literal +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic} +import org.apache.spark.sql.sources.v2.TableCapability._ +import org.apache.spark.sql.types.BooleanType + +object V2WriteSupportCheck extends (LogicalPlan => Unit) { + import DataSourceV2Implicits._ + + def failAnalysis(msg: String): Unit = throw new AnalysisException(msg) + + override def apply(plan: LogicalPlan): Unit = plan foreach { + case AppendData(rel: DataSourceV2Relation, _, _) if !rel.table.supports(BATCH_WRITE) => + failAnalysis(s"Table does not support append in batch mode: ${rel.table}") + + case OverwritePartitionsDynamic(rel: DataSourceV2Relation, _, _) + if !rel.table.supports(BATCH_WRITE) || !rel.table.supports(OVERWRITE_DYNAMIC) => + failAnalysis(s"Table does not support dynamic overwrite in batch mode: ${rel.table}") + + case OverwriteByExpression(rel: DataSourceV2Relation, expr, _, _) => + expr match { + case Literal(true, BooleanType) => + if (!rel.table.supports(BATCH_WRITE) || + !rel.table.supportsAny(TRUNCATE, OVERWRITE_BY_FILTER)) { + failAnalysis( + s"Table does not support truncate in batch mode: ${rel.table}") + } + case _ => + if (!rel.table.supports(BATCH_WRITE) || !rel.table.supports(OVERWRITE_BY_FILTER)) { + failAnalysis(s"Table does not support overwrite expression ${expr.sql} " + + s"in batch mode: ${rel.table}") + } + } + + case _ => // OK + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 51606abdb563a..607f2fa0f82c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -31,7 +31,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.sources.{AlwaysTrue, Filter} -import org.apache.spark.sql.sources.v2.SupportsBatchWrite +import org.apache.spark.sql.sources.v2.SupportsWrite import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsSaveMode, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{LongAccumulator, Utils} @@ -53,7 +53,7 @@ case class WriteToDataSourceV2(batchWrite: BatchWrite, query: LogicalPlan) * Rows in the output data set are appended. */ case class AppendDataExec( - table: SupportsBatchWrite, + table: SupportsWrite, writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { @@ -80,7 +80,7 @@ case class AppendDataExec( * AlwaysTrue to delete all rows. */ case class OverwriteByExpressionExec( - table: SupportsBatchWrite, + table: SupportsWrite, deleteWhere: Array[Filter], writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { @@ -101,7 +101,7 @@ case class OverwriteByExpressionExec( builder.overwrite(deleteWhere).buildForBatch() case _ => - throw new SparkException(s"Table does not support dynamic partition overwrite: $table") + throw new SparkException(s"Table does not support overwrite by expression: $table") } doWrite(batchWrite) @@ -118,7 +118,7 @@ case class OverwriteByExpressionExec( * are not modified. */ case class OverwritePartitionsDynamicExec( - table: SupportsBatchWrite, + table: SupportsWrite, writeOptions: CaseInsensitiveStringMap, query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { @@ -153,7 +153,7 @@ case class WriteToDataSourceV2Exec( * Helper for physical plans that build batch writes. 
*/ trait BatchWriteHelper { - def table: SupportsBatchWrite + def table: SupportsWrite def query: SparkPlan def writeOptions: CaseInsensitiveStringMap diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index dbdfcf8085604..884b92ae9421c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.streaming +import java.util +import java.util.Collections + import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.sources.ConsoleWrite import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, DataSourceRegister} @@ -63,6 +66,8 @@ object ConsoleTable extends Table with SupportsStreamingWrite { override def schema(): StructType = StructType(Nil) + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new WriteBuilder with SupportsTruncate { private var inputSchema: StructType = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index df7990c6a652e..bfa9c09985503 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming +import java.util +import java.util.Collections import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy @@ -97,6 +99,8 @@ class MemoryStreamTable(val stream: MemoryStreamBase[_]) extends Table override def schema(): StructType = stream.fullSchema() + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MemoryStreamScanBuilder(stream) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala index 44516bbb2a5a1..807e0b12c6278 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala @@ -17,12 +17,15 @@ package org.apache.spark.sql.execution.streaming.sources +import java.util +import java.util.Collections + import org.apache.spark.sql.{ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.python.PythonForeachWriter -import org.apache.spark.sql.sources.v2.{SupportsStreamingWrite, Table} +import org.apache.spark.sql.sources.v2.{SupportsStreamingWrite, Table, TableCapability} import org.apache.spark.sql.sources.v2.writer.{DataWriter, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType @@ -45,6 +48,8 @@ case class ForeachWriterTable[T]( override def schema(): 
StructType = StructType(Nil) + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new WriteBuilder with SupportsTruncate { private var inputSchema: StructType = _ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 3d8a90e99b85a..08aea75de2b5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.execution.streaming.sources +import java.util +import java.util.Collections + import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.SparkSession import org.apache.spark.sql.execution.streaming.continuous.RateStreamContinuousStream @@ -84,6 +87,8 @@ class RateStreamTable( override def schema(): StructType = RateStreamProvider.SCHEMA + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder { override def build(): Scan = new Scan { override def readSchema(): StructType = RateStreamProvider.SCHEMA diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala index 0adbf1d9b3689..c0292acdf1044 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.streaming.sources import java.text.SimpleDateFormat -import java.util.Locale +import java.util +import java.util.{Collections, Locale} import scala.util.{Failure, Success, Try} @@ -78,6 +79,8 @@ class TextSocketTable(host: String, port: Int, numPartitions: Int, includeTimest } } + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder { override def build(): Scan = new Scan { override def readSchema(): StructType = schema() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 22adceba930fb..8eb5de0f640a4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.streaming.sources +import java.util +import java.util.Collections import javax.annotation.concurrent.GuardedBy import scala.collection.mutable @@ -31,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} -import org.apache.spark.sql.sources.v2.SupportsStreamingWrite +import 
org.apache.spark.sql.sources.v2.{SupportsStreamingWrite, TableCapability} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType @@ -47,6 +49,8 @@ class MemorySinkV2 extends SupportsStreamingWrite with MemorySinkBase with Loggi override def schema(): StructType = StructType(Nil) + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new WriteBuilder with SupportsTruncate { private var needTruncate: Boolean = false diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index a605dc640dc96..f05aa5113e03a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser} import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.v2.V2WriteSupportCheck import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager @@ -172,6 +173,7 @@ abstract class BaseSessionStateBuilder( PreWriteCheck +: PreReadCheck +: HiveOnlyCheck +: + V2WriteSupportCheck +: customCheckRules } diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java index cb5954d5a6211..9b0eb610a206f 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaSimpleBatchTable.java @@ -18,15 +18,23 @@ package test.org.apache.spark.sql.sources.v2; import java.io.IOException; +import java.util.Arrays; +import java.util.HashSet; +import java.util.Set; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; -import org.apache.spark.sql.sources.v2.SupportsBatchRead; +import org.apache.spark.sql.sources.v2.SupportsRead; import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.TableCapability; import org.apache.spark.sql.sources.v2.reader.*; import org.apache.spark.sql.types.StructType; -abstract class JavaSimpleBatchTable implements Table, SupportsBatchRead { +abstract class JavaSimpleBatchTable implements Table, SupportsRead { + private static final Set CAPABILITIES = new HashSet<>(Arrays.asList( + TableCapability.BATCH_READ, + TableCapability.BATCH_WRITE, + TableCapability.TRUNCATE)); @Override public StructType schema() { @@ -37,6 +45,11 @@ public StructType schema() { public String name() { return this.getClass().toString(); } + + @Override + public Set capabilities() { + return CAPABILITIES; + } } abstract class JavaSimpleScanBuilder implements ScanBuilder, Scan, Batch { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 705559d099bec..587cfa9bd6647 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -18,8 +18,11 @@ package org.apache.spark.sql.sources.v2 import java.io.File +import java.util import java.util.OptionalLong +import scala.collection.JavaConverters._ + import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException @@ -30,6 +33,7 @@ import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector import org.apache.spark.sql.functions._ import org.apache.spark.sql.sources.{Filter, GreaterThan} +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.partitioning.{ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.test.SharedSQLContext @@ -411,11 +415,13 @@ object SimpleReaderFactory extends PartitionReaderFactory { } } -abstract class SimpleBatchTable extends Table with SupportsBatchRead { +abstract class SimpleBatchTable extends Table with SupportsRead { override def schema(): StructType = new StructType().add("i", "int").add("j", "int") override def name(): String = this.getClass.toString + + override def capabilities(): util.Set[TableCapability] = Set(BATCH_READ).asJava } abstract class SimpleScanBuilder extends ScanBuilder diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala index f9f9db35ac2dd..e019dbfe3f512 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.sources.v2 +import scala.collection.JavaConverters._ + import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.execution.datasources.FileFormat import org.apache.spark.sql.execution.datasources.parquet.ParquetFileFormat @@ -38,7 +40,7 @@ class DummyReadOnlyFileDataSourceV2 extends FileDataSourceV2 { } } -class DummyReadOnlyFileTable extends Table with SupportsBatchRead { +class DummyReadOnlyFileTable extends Table with SupportsRead { override def name(): String = "dummy" override def schema(): StructType = StructType(Nil) @@ -46,6 +48,9 @@ class DummyReadOnlyFileTable extends Table with SupportsBatchRead { override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { throw new AnalysisException("Dummy file reader") } + + override def capabilities(): java.util.Set[TableCapability] = + Set(TableCapability.BATCH_READ).asJava } class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { @@ -59,13 +64,16 @@ class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { } } -class DummyWriteOnlyFileTable extends Table with SupportsBatchWrite { +class DummyWriteOnlyFileTable extends Table with SupportsWrite { override def name(): String = "dummy" override def schema(): StructType = StructType(Nil) override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = throw new AnalysisException("Dummy file writer") + + override def capabilities(): java.util.Set[TableCapability] = + Set(TableCapability.BATCH_WRITE).asJava } class FileDataSourceV2FallBackSuite extends QueryTest with SharedSQLContext { diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index 160354520e432..edebb0b62b29c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.sources.v2 import java.io.{BufferedReader, InputStreamReader, IOException} +import java.util import scala.collection.JavaConverters._ @@ -27,6 +28,7 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkContext import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.types.StructType @@ -142,7 +144,7 @@ class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { } class MyTable(options: CaseInsensitiveStringMap) - extends SimpleBatchTable with SupportsBatchWrite { + extends SimpleBatchTable with SupportsWrite { private val path = options.get("path") private val conf = SparkContext.getActive.get.hadoopConfiguration @@ -156,6 +158,9 @@ class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new MyWriteBuilder(path) } + + override def capabilities(): util.Set[TableCapability] = + Set(BATCH_READ, BATCH_WRITE, TRUNCATE).asJava } override def getTable(options: CaseInsensitiveStringMap): Table = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V2WriteSupportCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V2WriteSupportCheckSuite.scala new file mode 100644 index 0000000000000..1d76ee34a0e0b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/V2WriteSupportCheckSuite.scala @@ -0,0 +1,149 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.sources.v2 + +import java.util + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, NamedRelation} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, EqualTo, Literal} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, OverwriteByExpression, OverwritePartitionsDynamic} +import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, V2WriteSupportCheck} +import org.apache.spark.sql.sources.v2.TableCapability._ +import org.apache.spark.sql.types.{LongType, StringType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class V2WriteSupportCheckSuite extends AnalysisTest { + + test("AppendData: check missing capabilities") { + val plan = AppendData.byName( + DataSourceV2Relation.create(CapabilityTable(), CaseInsensitiveStringMap.empty), TestRelation) + + val exc = intercept[AnalysisException]{ + V2WriteSupportCheck.apply(plan) + } + + assert(exc.getMessage.contains("does not support append in batch mode")) + } + + test("AppendData: check correct capabilities") { + val plan = AppendData.byName( + DataSourceV2Relation.create(CapabilityTable(BATCH_WRITE), CaseInsensitiveStringMap.empty), + TestRelation) + + V2WriteSupportCheck.apply(plan) + } + + test("Truncate: check missing capabilities") { + Seq(CapabilityTable(), + CapabilityTable(BATCH_WRITE), + CapabilityTable(TRUNCATE), + CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table => + + val plan = OverwriteByExpression.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation, + Literal(true)) + + val exc = intercept[AnalysisException]{ + V2WriteSupportCheck.apply(plan) + } + + assert(exc.getMessage.contains("does not support truncate in batch mode")) + } + } + + test("Truncate: check correct capabilities") { + Seq(CapabilityTable(BATCH_WRITE, TRUNCATE), + CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER)).foreach { table => + + val plan = OverwriteByExpression.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation, + Literal(true)) + + V2WriteSupportCheck.apply(plan) + } + } + + test("OverwriteByExpression: check missing capabilities") { + Seq(CapabilityTable(), + CapabilityTable(BATCH_WRITE), + CapabilityTable(OVERWRITE_BY_FILTER)).foreach { table => + + val plan = OverwriteByExpression.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation, + EqualTo(AttributeReference("x", LongType)(), Literal(5))) + + val exc = intercept[AnalysisException]{ + V2WriteSupportCheck.apply(plan) + } + + assert(exc.getMessage.contains( + "does not support overwrite expression (`x` = 5) in batch mode")) + } + } + + test("OverwriteByExpression: check correct capabilities") { + val table = CapabilityTable(BATCH_WRITE, OVERWRITE_BY_FILTER) + val plan = OverwriteByExpression.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation, + EqualTo(AttributeReference("x", LongType)(), Literal(5))) + + V2WriteSupportCheck.apply(plan) + } + + test("OverwritePartitionsDynamic: check missing capabilities") { + Seq(CapabilityTable(), + CapabilityTable(BATCH_WRITE), + CapabilityTable(OVERWRITE_DYNAMIC)).foreach { table => + + val plan = OverwritePartitionsDynamic.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation) + + val exc = intercept[AnalysisException] { + 
V2WriteSupportCheck.apply(plan) + } + + assert(exc.getMessage.contains("does not support dynamic overwrite in batch mode")) + } + } + + test("OverwritePartitionsDynamic: check correct capabilities") { + val table = CapabilityTable(BATCH_WRITE, OVERWRITE_DYNAMIC) + val plan = OverwritePartitionsDynamic.byName( + DataSourceV2Relation.create(table, CaseInsensitiveStringMap.empty), TestRelation) + + V2WriteSupportCheck.apply(plan) + } +} + +private object V2WriteSupportCheckSuite { + val schema: StructType = new StructType().add("id", LongType).add("data", StringType) +} + +private case object TestRelation extends LeafNode with NamedRelation { + override def name: String = "source_relation" + override def output: Seq[AttributeReference] = V2WriteSupportCheckSuite.schema.toAttributes +} + +private case class CapabilityTable(_capabilities: TableCapability*) extends Table { + override def name(): String = "capability_test_table" + override def schema(): StructType = V2WriteSupportCheckSuite.schema + override def capabilities(): util.Set[TableCapability] = _capabilities.toSet.asJava +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 13bb686fbd3b9..f022edea275e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -17,6 +17,9 @@ package org.apache.spark.sql.streaming.sources +import java.util +import java.util.Collections + import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper} @@ -77,18 +80,21 @@ class FakeWriteBuilder extends WriteBuilder with StreamingWrite { trait FakeMicroBatchReadTable extends Table with SupportsMicroBatchRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new FakeScanBuilder } trait FakeContinuousReadTable extends Table with SupportsContinuousRead { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new FakeScanBuilder } trait FakeStreamingWriteTable extends Table with SupportsStreamingWrite { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new FakeWriteBuilder } @@ -137,6 +143,7 @@ class FakeReadNeitherMode extends DataSourceRegister with TableProvider { new Table { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) + override def capabilities(): util.Set[TableCapability] = Collections.emptySet() } } } @@ -164,6 +171,7 @@ class FakeNoWrite extends DataSourceRegister with TableProvider { new Table { override def name(): String = "fake" override def schema(): StructType = StructType(Nil) + override def capabilities(): util.Set[TableCapability] = 
Collections.emptySet() } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 132b0e4db0d71..68f4b2ddbac0b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlanner import org.apache.spark.sql.execution.datasources._ +import org.apache.spark.sql.execution.datasources.v2.V2WriteSupportCheck import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState} @@ -86,6 +87,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session override val extendedCheckRules: Seq[LogicalPlan => Unit] = PreWriteCheck +: PreReadCheck +: + V2WriteSupportCheck +: customCheckRules } From c46db75b4e64039d6a72a0262d81ac860139b4e8 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 25 Mar 2019 17:43:03 -0700 Subject: [PATCH 19/70] [SPARK-27209][SQL] Split parsing of SELECT and INSERT into two top-level rules in the grammar file. Currently in the grammar file the rule `query` is responsible to parse both select and insert statements. As a result, we need to have more semantic checks in the code to guard against in-valid insert constructs in a query. Couple of examples are in the `visitCreateView` and `visitAlterView` functions. One other issue is that, we don't catch the `invalid insert constructs` in all the places until checkAnalysis (the errors we raise can be confusing as well). Here are couple of examples : ```SQL select * from (insert into bar values (2)); ``` ``` Error in query: unresolved operator 'Project [*]; 'Project [*] +- SubqueryAlias `__auto_generated_subquery_name` +- InsertIntoHiveTable `default`.`bar`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, false, false, [c1] +- Project [cast(col1#18 as int) AS c1#20] +- LocalRelation [col1#18] ``` ```SQL select * from foo where c1 in (insert into bar values (2)) ``` ``` Error in query: cannot resolve '(default.foo.`c1` IN (listquery()))' due to data type mismatch: The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery. Left side columns: [default.foo.`c1`]. Right side columns: [].;; 'Project [*] +- 'Filter c1#6 IN (list#5 []) : +- InsertIntoHiveTable `default`.`bar`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, false, false, [c1] : +- Project [cast(col1#7 as int) AS c1#9] : +- LocalRelation [col1#7] +- SubqueryAlias `default`.`foo` +- HiveTableRelation `default`.`foo`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [c1#6] ``` For both the cases above, we should reject the syntax at parser level. In this PR, we create two top-level parser rules to parse `SELECT` and `INSERT` respectively. I will create a small PR to allow CTEs in DESCRIBE QUERY after this PR is in. Added tests to PlanParserSuite and removed the semantic check tests from SparkSqlParserSuites. Closes #24150 from dilipbiswal/split-query-insert. 
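Since the effect of the grammar split is easiest to see from parser behavior, here is a small sketch in the spirit of the `PlanParserSuite` tests this patch adds. `CatalystSqlParser.parsePlan` is the standard Catalyst parser entry point and the expected-message fragment follows those tests; the wrapper object is invented for illustration.

```scala
import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException}

object ParserRejectionSketch {
  // With INSERT split out of the sub-query grammar, these statements now fail
  // while parsing instead of producing confusing analysis errors.
  def expectParseError(sql: String): Unit =
    try {
      CatalystSqlParser.parsePlan(sql)
      sys.error(s"expected a ParseException for: $sql")
    } catch {
      case e: ParseException =>
        assert(e.getMessage.contains("mismatched input"))
    }

  def main(args: Array[String]): Unit = {
    expectParseError("SELECT * FROM (INSERT INTO bar VALUES (2))")
    expectParseError("SELECT * FROM foo WHERE c1 IN (INSERT INTO bar VALUES (2))")
  }
}
```

Note that the very next patch in this series reverts SPARK-27209 on this branch, so the parse-time rejection sketched above only holds while patch 19/70 is applied.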
Authored-by: Dilip Biswal Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/parser/SqlBase.g4 | 23 ++++-- .../sql/catalyst/parser/AstBuilder.scala | 81 ++++++++++++++----- .../sql/catalyst/parser/PlanParserSuite.scala | 49 ++++++++++- .../spark/sql/execution/SparkSqlParser.scala | 17 ---- .../sql/execution/SparkSqlParserSuite.scala | 53 ++++++------ 5 files changed, 157 insertions(+), 66 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 76a13c5e2478f..e527d186210d6 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -76,6 +76,8 @@ singleTableSchema statement : query #statementDefault + | insertStatement #insertStatementDefault + | multiSelectStatement #multiSelectStatementDefault | USE db=identifier #use | CREATE DATABASE (IF NOT EXISTS)? identifier (COMMENT comment=STRING)? locationSpec? @@ -343,9 +345,14 @@ resource : identifier STRING ; +insertStatement + : (ctes)? insertInto queryTerm queryOrganization #singleInsertQuery + | (ctes)? fromClause multiInsertQueryBody+ #multiInsertQuery + ; + queryNoWith - : insertInto? queryTerm queryOrganization #singleInsertQuery - | fromClause multiInsertQueryBody+ #multiInsertQuery + : queryTerm queryOrganization #noWithQuery + | fromClause selectStatement #queryWithFrom ; queryOrganization @@ -358,9 +365,15 @@ queryOrganization ; multiInsertQueryBody - : insertInto? - querySpecification - queryOrganization + : insertInto selectStatement + ; + +multiSelectStatement + : (ctes)? fromClause selectStatement+ #multiSelect + ; + +selectStatement + : querySpecification queryOrganization ; queryTerm diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index aa6d8cf7e5ad0..1eba982426117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -111,15 +111,34 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val query = plan(ctx.queryNoWith) // Apply CTEs - query.optional(ctx.ctes) { - val ctes = ctx.ctes.namedQuery.asScala.map { nCtx => - val namedQuery = visitNamedQuery(nCtx) - (namedQuery.alias, namedQuery) - } - // Check for duplicate names. - checkDuplicateKeys(ctes, ctx) - With(query, ctes) + query.optionalMap(ctx.ctes)(withCTE) + } + + private def withCTE(ctx: CtesContext, plan: LogicalPlan): LogicalPlan = { + val ctes = ctx.namedQuery.asScala.map { nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) } + // Check for duplicate names. + checkDuplicateKeys(ctes, ctx) + With(plan, ctes) + } + + override def visitQueryToDesc(ctx: QueryToDescContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses) + } + + override def visitQueryWithFrom(ctx: QueryWithFromContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + validate(ctx.selectStatement.querySpecification.fromClause == null, + "Individual select statement can not have FROM cause as its already specified in the" + + " outer query block", ctx) + withQuerySpecification(ctx.selectStatement.querySpecification, from). 
+ optionalMap(ctx.selectStatement.queryOrganization)(withQueryResultClauses) + } + + override def visitNoWithQuery(ctx: NoWithQueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses) } /** @@ -151,24 +170,49 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val from = visitFromClause(ctx.fromClause) // Build the insert clauses. - val inserts = ctx.multiInsertQueryBody.asScala.map { + val inserts = ctx.multiInsertQueryBody().asScala.map { body => - validate(body.querySpecification.fromClause == null, + validate(body.selectStatement.querySpecification.fromClause == null, "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", body) + withInsertInto(body.insertInto, + withQuerySpecification(body.selectStatement.querySpecification, from). + // Add organization statements. + optionalMap(body.selectStatement.queryOrganization)(withQueryResultClauses)) + } + + // If there are multiple INSERTS just UNION them together into one query. + val insertPlan = inserts match { + case Seq(query) => query + case queries => Union(queries) + } + // Apply CTEs + insertPlan.optionalMap(ctx.ctes)(withCTE) + } + + override def visitMultiSelect(ctx: MultiSelectContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + + // Build the insert clauses. + val selects = ctx.selectStatement.asScala.map { + body => + validate(body.querySpecification.fromClause == null, + "Multi-select queries cannot have a FROM clause in their individual SELECT statements", + body) + withQuerySpecification(body.querySpecification, from). // Add organization statements. - optionalMap(body.queryOrganization)(withQueryResultClauses). - // Add insert. - optionalMap(body.insertInto())(withInsertInto) + optionalMap(body.queryOrganization)(withQueryResultClauses) } // If there are multiple INSERTS just UNION them together into one query. - inserts match { + val selectUnionPlan = selects match { case Seq(query) => query case queries => Union(queries) } + // Apply CTEs + selectUnionPlan.optionalMap(ctx.ctes)(withCTE) } /** @@ -176,11 +220,10 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging */ override def visitSingleInsertQuery( ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { - plan(ctx.queryTerm). - // Add organization statements. - optionalMap(ctx.queryOrganization)(withQueryResultClauses). - // Add insert. 
- optionalMap(ctx.insertInto())(withInsertInto) + val insertPlan = withInsertInto(ctx.insertInto(), + plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses)) + // Apply CTEs + insertPlan.optionalMap(ctx.ctes)(withCTE) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index f5da90f7cf0c6..5ef3b2b7615d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -132,7 +132,11 @@ class PlanParserSuite extends AnalysisTest { table("a").select(star()).union(table("a").where('s < 10).select(star()))) intercept( "from a select * select * from x where a.s < 10", - "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements") + "Multi-select queries cannot have a FROM clause in their individual SELECT statements") + intercept( + "from a select * from b", + "Individual select statement can not have FROM cause as its already specified in " + + "the outer query block") assertEqual( "from a insert into tbl1 select * insert into tbl2 select * where s < 10", table("a").select(star()).insertInto("tbl1").union( @@ -753,4 +757,47 @@ class PlanParserSuite extends AnalysisTest { assertEqual(query2, Distinct(a.union(b)).except(c.intersect(d, isAll = true), isAll = true)) } } + + test("create/alter view as insert into table") { + val m1 = intercept[ParseException] { + parsePlan("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)") + }.getMessage + assert(m1.contains("mismatched input 'INSERT' expecting")) + // Multi insert query + val m2 = intercept[ParseException] { + parsePlan( + """ + |CREATE VIEW testView AS FROM jt + |INSERT INTO tbl1 SELECT * WHERE jt.id < 5 + |INSERT INTO tbl2 SELECT * WHERE jt.id > 4 + """.stripMargin) + }.getMessage + assert(m2.contains("mismatched input 'INSERT' expecting")) + val m3 = intercept[ParseException] { + parsePlan("ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)") + }.getMessage + assert(m3.contains("mismatched input 'INSERT' expecting")) + // Multi insert query + val m4 = intercept[ParseException] { + parsePlan( + """ + |ALTER VIEW testView AS FROM jt + |INSERT INTO tbl1 SELECT * WHERE jt.id < 5 + |INSERT INTO tbl2 SELECT * WHERE jt.id > 4 + """.stripMargin + ) + }.getMessage + assert(m4.contains("mismatched input 'INSERT' expecting")) + } + + test("Invalid insert constructs in the query") { + val m1 = intercept[ParseException] { + parsePlan("SELECT * FROM (INSERT INTO BAR VALUES (2))") + }.getMessage + assert(m1.contains("mismatched input 'FROM' expecting")) + val m2 = intercept[ParseException] { + parsePlan("SELECT * FROM S WHERE C1 IN (INSERT INTO T VALUES (2))") + }.getMessage + assert(m2.contains("mismatched input 'FROM' expecting")) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 8deb55b00a9d3..6757efd19b5a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -1431,15 +1431,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { if (ctx.identifierList != null) { operationNotAllowed("CREATE VIEW ... PARTITIONED ON", ctx) } else { - // CREATE VIEW ... 
AS INSERT INTO is not allowed. - ctx.query.queryNoWith match { - case s: SingleInsertQueryContext if s.insertInto != null => - operationNotAllowed("CREATE VIEW ... AS INSERT INTO", ctx) - case _: MultiInsertQueryContext => - operationNotAllowed("CREATE VIEW ... AS FROM ... [INSERT INTO ...]+", ctx) - case _ => // OK - } - val userSpecifiedColumns = Option(ctx.identifierCommentList).toSeq.flatMap { icl => icl.identifierComment.asScala.map { ic => ic.identifier.getText -> Option(ic.STRING).map(string) @@ -1476,14 +1467,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) { - // ALTER VIEW ... AS INSERT INTO is not allowed. - ctx.query.queryNoWith match { - case s: SingleInsertQueryContext if s.insertInto != null => - operationNotAllowed("ALTER VIEW ... AS INSERT INTO", ctx) - case _: MultiInsertQueryContext => - operationNotAllowed("ALTER VIEW ... AS FROM ... [INSERT INTO ...]+", ctx) - case _ => // OK - } AlterViewAsCommand( name = visitTableIdentifier(ctx.tableIdentifier), originalText = source(ctx.query), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 31b9bcdafbab8..a7a12cb6eebb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -215,30 +215,6 @@ class SparkSqlParserSuite extends AnalysisTest { "no viable alternative at input") } - test("create table using - schema") { - assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet", - createTableUsing( - table = "my_tab", - schema = (new StructType) - .add("a", IntegerType, nullable = true, "test") - .add("b", StringType) - ) - ) - intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING) USING parquet", - "no viable alternative at input") - } - - test("create view as insert into table") { - // Single insert query - intercept("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)", - "Operation not allowed: CREATE VIEW ... AS INSERT INTO") - - // Multi insert query - intercept("CREATE VIEW testView AS FROM jt INSERT INTO tbl1 SELECT * WHERE jt.id < 5 " + - "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", - "Operation not allowed: CREATE VIEW ... AS FROM ... [INSERT INTO ...]+") - } - test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") { assertEqual("describe table t", DescribeTableCommand( @@ -377,6 +353,18 @@ class SparkSqlParserSuite extends AnalysisTest { Project(UnresolvedAlias(concat) :: Nil, UnresolvedRelation(TableIdentifier("t")))) } +<<<<<<< HEAD + test("SPARK-25046 Fix Alter View ... As Insert Into Table") { + // Single insert query + intercept("ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)", + "Operation not allowed: ALTER VIEW ... AS INSERT INTO") + + // Multi insert query + intercept("ALTER VIEW testView AS FROM jt INSERT INTO tbl1 SELECT * WHERE jt.id < 5 " + + "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", + "Operation not allowed: ALTER VIEW ... AS FROM ... [INSERT INTO ...]+") + } +||||||| parent of 9cc925cda2... [SPARK-27209][SQL] Split parsing of SELECT and INSERT into two top-level rules in the grammar file. test("SPARK-25046 Fix Alter View ... 
As Insert Into Table") { // Single insert query intercept("ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)", @@ -387,4 +375,21 @@ class SparkSqlParserSuite extends AnalysisTest { "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", "Operation not allowed: ALTER VIEW ... AS FROM ... [INSERT INTO ...]+") } + + test("database and schema tokens are interchangeable") { + assertEqual("CREATE DATABASE foo", parser.parsePlan("CREATE SCHEMA foo")) + assertEqual("DROP DATABASE foo", parser.parsePlan("DROP SCHEMA foo")) + assertEqual("ALTER DATABASE foo SET DBPROPERTIES ('x' = 'y')", + parser.parsePlan("ALTER SCHEMA foo SET DBPROPERTIES ('x' = 'y')")) + assertEqual("DESC DATABASE foo", parser.parsePlan("DESC SCHEMA foo")) + } +======= + test("database and schema tokens are interchangeable") { + assertEqual("CREATE DATABASE foo", parser.parsePlan("CREATE SCHEMA foo")) + assertEqual("DROP DATABASE foo", parser.parsePlan("DROP SCHEMA foo")) + assertEqual("ALTER DATABASE foo SET DBPROPERTIES ('x' = 'y')", + parser.parsePlan("ALTER SCHEMA foo SET DBPROPERTIES ('x' = 'y')")) + assertEqual("DESC DATABASE foo", parser.parsePlan("DESC SCHEMA foo")) + } +>>>>>>> 9cc925cda2... [SPARK-27209][SQL] Split parsing of SELECT and INSERT into two top-level rules in the grammar file. } From bc6ece7dfe77e129425f773887992d4fe4106700 Mon Sep 17 00:00:00 2001 From: mcheah Date: Wed, 15 May 2019 14:42:55 -0700 Subject: [PATCH 20/70] Revert "[SPARK-27209][SQL] Split parsing of SELECT and INSERT into two top-level rules in the grammar file." This reverts commit c46db75b4e64039d6a72a0262d81ac860139b4e8. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 23 ++---- .../sql/catalyst/parser/AstBuilder.scala | 81 +++++-------------- .../sql/catalyst/parser/PlanParserSuite.scala | 49 +---------- .../spark/sql/execution/SparkSqlParser.scala | 17 ++++ .../sql/execution/SparkSqlParserSuite.scala | 53 ++++++------ 5 files changed, 66 insertions(+), 157 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index e527d186210d6..76a13c5e2478f 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -76,8 +76,6 @@ singleTableSchema statement : query #statementDefault - | insertStatement #insertStatementDefault - | multiSelectStatement #multiSelectStatementDefault | USE db=identifier #use | CREATE DATABASE (IF NOT EXISTS)? identifier (COMMENT comment=STRING)? locationSpec? @@ -345,14 +343,9 @@ resource : identifier STRING ; -insertStatement - : (ctes)? insertInto queryTerm queryOrganization #singleInsertQuery - | (ctes)? fromClause multiInsertQueryBody+ #multiInsertQuery - ; - queryNoWith - : queryTerm queryOrganization #noWithQuery - | fromClause selectStatement #queryWithFrom + : insertInto? queryTerm queryOrganization #singleInsertQuery + | fromClause multiInsertQueryBody+ #multiInsertQuery ; queryOrganization @@ -365,15 +358,9 @@ queryOrganization ; multiInsertQueryBody - : insertInto selectStatement - ; - -multiSelectStatement - : (ctes)? fromClause selectStatement+ #multiSelect - ; - -selectStatement - : querySpecification queryOrganization + : insertInto? 
+ querySpecification + queryOrganization ; queryTerm diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 1eba982426117..aa6d8cf7e5ad0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -111,34 +111,15 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val query = plan(ctx.queryNoWith) // Apply CTEs - query.optionalMap(ctx.ctes)(withCTE) - } - - private def withCTE(ctx: CtesContext, plan: LogicalPlan): LogicalPlan = { - val ctes = ctx.namedQuery.asScala.map { nCtx => - val namedQuery = visitNamedQuery(nCtx) - (namedQuery.alias, namedQuery) + query.optional(ctx.ctes) { + val ctes = ctx.ctes.namedQuery.asScala.map { nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) + } + // Check for duplicate names. + checkDuplicateKeys(ctes, ctx) + With(query, ctes) } - // Check for duplicate names. - checkDuplicateKeys(ctes, ctx) - With(plan, ctes) - } - - override def visitQueryToDesc(ctx: QueryToDescContext): LogicalPlan = withOrigin(ctx) { - plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses) - } - - override def visitQueryWithFrom(ctx: QueryWithFromContext): LogicalPlan = withOrigin(ctx) { - val from = visitFromClause(ctx.fromClause) - validate(ctx.selectStatement.querySpecification.fromClause == null, - "Individual select statement can not have FROM cause as its already specified in the" + - " outer query block", ctx) - withQuerySpecification(ctx.selectStatement.querySpecification, from). - optionalMap(ctx.selectStatement.queryOrganization)(withQueryResultClauses) - } - - override def visitNoWithQuery(ctx: NoWithQueryContext): LogicalPlan = withOrigin(ctx) { - plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses) } /** @@ -170,49 +151,24 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val from = visitFromClause(ctx.fromClause) // Build the insert clauses. - val inserts = ctx.multiInsertQueryBody().asScala.map { - body => - validate(body.selectStatement.querySpecification.fromClause == null, - "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", - body) - - withInsertInto(body.insertInto, - withQuerySpecification(body.selectStatement.querySpecification, from). - // Add organization statements. - optionalMap(body.selectStatement.queryOrganization)(withQueryResultClauses)) - } - - // If there are multiple INSERTS just UNION them together into one query. - val insertPlan = inserts match { - case Seq(query) => query - case queries => Union(queries) - } - // Apply CTEs - insertPlan.optionalMap(ctx.ctes)(withCTE) - } - - override def visitMultiSelect(ctx: MultiSelectContext): LogicalPlan = withOrigin(ctx) { - val from = visitFromClause(ctx.fromClause) - - // Build the insert clauses. - val selects = ctx.selectStatement.asScala.map { + val inserts = ctx.multiInsertQueryBody.asScala.map { body => validate(body.querySpecification.fromClause == null, - "Multi-select queries cannot have a FROM clause in their individual SELECT statements", + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", body) withQuerySpecification(body.querySpecification, from). // Add organization statements. 
- optionalMap(body.queryOrganization)(withQueryResultClauses) + optionalMap(body.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(body.insertInto())(withInsertInto) } // If there are multiple INSERTS just UNION them together into one query. - val selectUnionPlan = selects match { + inserts match { case Seq(query) => query case queries => Union(queries) } - // Apply CTEs - selectUnionPlan.optionalMap(ctx.ctes)(withCTE) } /** @@ -220,10 +176,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging */ override def visitSingleInsertQuery( ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { - val insertPlan = withInsertInto(ctx.insertInto(), - plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses)) - // Apply CTEs - insertPlan.optionalMap(ctx.ctes)(withCTE) + plan(ctx.queryTerm). + // Add organization statements. + optionalMap(ctx.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(ctx.insertInto())(withInsertInto) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 5ef3b2b7615d4..f5da90f7cf0c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -132,11 +132,7 @@ class PlanParserSuite extends AnalysisTest { table("a").select(star()).union(table("a").where('s < 10).select(star()))) intercept( "from a select * select * from x where a.s < 10", - "Multi-select queries cannot have a FROM clause in their individual SELECT statements") - intercept( - "from a select * from b", - "Individual select statement can not have FROM cause as its already specified in " + - "the outer query block") + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements") assertEqual( "from a insert into tbl1 select * insert into tbl2 select * where s < 10", table("a").select(star()).insertInto("tbl1").union( @@ -757,47 +753,4 @@ class PlanParserSuite extends AnalysisTest { assertEqual(query2, Distinct(a.union(b)).except(c.intersect(d, isAll = true), isAll = true)) } } - - test("create/alter view as insert into table") { - val m1 = intercept[ParseException] { - parsePlan("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)") - }.getMessage - assert(m1.contains("mismatched input 'INSERT' expecting")) - // Multi insert query - val m2 = intercept[ParseException] { - parsePlan( - """ - |CREATE VIEW testView AS FROM jt - |INSERT INTO tbl1 SELECT * WHERE jt.id < 5 - |INSERT INTO tbl2 SELECT * WHERE jt.id > 4 - """.stripMargin) - }.getMessage - assert(m2.contains("mismatched input 'INSERT' expecting")) - val m3 = intercept[ParseException] { - parsePlan("ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)") - }.getMessage - assert(m3.contains("mismatched input 'INSERT' expecting")) - // Multi insert query - val m4 = intercept[ParseException] { - parsePlan( - """ - |ALTER VIEW testView AS FROM jt - |INSERT INTO tbl1 SELECT * WHERE jt.id < 5 - |INSERT INTO tbl2 SELECT * WHERE jt.id > 4 - """.stripMargin - ) - }.getMessage - assert(m4.contains("mismatched input 'INSERT' expecting")) - } - - test("Invalid insert constructs in the query") { - val m1 = intercept[ParseException] { - parsePlan("SELECT * FROM (INSERT INTO BAR VALUES (2))") - }.getMessage - assert(m1.contains("mismatched input 'FROM' expecting")) - 
val m2 = intercept[ParseException] { - parsePlan("SELECT * FROM S WHERE C1 IN (INSERT INTO T VALUES (2))") - }.getMessage - assert(m2.contains("mismatched input 'FROM' expecting")) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 6757efd19b5a9..8deb55b00a9d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -1431,6 +1431,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { if (ctx.identifierList != null) { operationNotAllowed("CREATE VIEW ... PARTITIONED ON", ctx) } else { + // CREATE VIEW ... AS INSERT INTO is not allowed. + ctx.query.queryNoWith match { + case s: SingleInsertQueryContext if s.insertInto != null => + operationNotAllowed("CREATE VIEW ... AS INSERT INTO", ctx) + case _: MultiInsertQueryContext => + operationNotAllowed("CREATE VIEW ... AS FROM ... [INSERT INTO ...]+", ctx) + case _ => // OK + } + val userSpecifiedColumns = Option(ctx.identifierCommentList).toSeq.flatMap { icl => icl.identifierComment.asScala.map { ic => ic.identifier.getText -> Option(ic.STRING).map(string) @@ -1467,6 +1476,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) { + // ALTER VIEW ... AS INSERT INTO is not allowed. + ctx.query.queryNoWith match { + case s: SingleInsertQueryContext if s.insertInto != null => + operationNotAllowed("ALTER VIEW ... AS INSERT INTO", ctx) + case _: MultiInsertQueryContext => + operationNotAllowed("ALTER VIEW ... AS FROM ... [INSERT INTO ...]+", ctx) + case _ => // OK + } AlterViewAsCommand( name = visitTableIdentifier(ctx.tableIdentifier), originalText = source(ctx.query), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index a7a12cb6eebb4..31b9bcdafbab8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -215,6 +215,30 @@ class SparkSqlParserSuite extends AnalysisTest { "no viable alternative at input") } + test("create table using - schema") { + assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet", + createTableUsing( + table = "my_tab", + schema = (new StructType) + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType) + ) + ) + intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING) USING parquet", + "no viable alternative at input") + } + + test("create view as insert into table") { + // Single insert query + intercept("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)", + "Operation not allowed: CREATE VIEW ... AS INSERT INTO") + + // Multi insert query + intercept("CREATE VIEW testView AS FROM jt INSERT INTO tbl1 SELECT * WHERE jt.id < 5 " + + "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", + "Operation not allowed: CREATE VIEW ... AS FROM ... 
[INSERT INTO ...]+") + } + test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") { assertEqual("describe table t", DescribeTableCommand( @@ -353,18 +377,6 @@ class SparkSqlParserSuite extends AnalysisTest { Project(UnresolvedAlias(concat) :: Nil, UnresolvedRelation(TableIdentifier("t")))) } -<<<<<<< HEAD - test("SPARK-25046 Fix Alter View ... As Insert Into Table") { - // Single insert query - intercept("ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)", - "Operation not allowed: ALTER VIEW ... AS INSERT INTO") - - // Multi insert query - intercept("ALTER VIEW testView AS FROM jt INSERT INTO tbl1 SELECT * WHERE jt.id < 5 " + - "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", - "Operation not allowed: ALTER VIEW ... AS FROM ... [INSERT INTO ...]+") - } -||||||| parent of 9cc925cda2... [SPARK-27209][SQL] Split parsing of SELECT and INSERT into two top-level rules in the grammar file. test("SPARK-25046 Fix Alter View ... As Insert Into Table") { // Single insert query intercept("ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)", @@ -375,21 +387,4 @@ class SparkSqlParserSuite extends AnalysisTest { "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", "Operation not allowed: ALTER VIEW ... AS FROM ... [INSERT INTO ...]+") } - - test("database and schema tokens are interchangeable") { - assertEqual("CREATE DATABASE foo", parser.parsePlan("CREATE SCHEMA foo")) - assertEqual("DROP DATABASE foo", parser.parsePlan("DROP SCHEMA foo")) - assertEqual("ALTER DATABASE foo SET DBPROPERTIES ('x' = 'y')", - parser.parsePlan("ALTER SCHEMA foo SET DBPROPERTIES ('x' = 'y')")) - assertEqual("DESC DATABASE foo", parser.parsePlan("DESC SCHEMA foo")) - } -======= - test("database and schema tokens are interchangeable") { - assertEqual("CREATE DATABASE foo", parser.parsePlan("CREATE SCHEMA foo")) - assertEqual("DROP DATABASE foo", parser.parsePlan("DROP SCHEMA foo")) - assertEqual("ALTER DATABASE foo SET DBPROPERTIES ('x' = 'y')", - parser.parsePlan("ALTER SCHEMA foo SET DBPROPERTIES ('x' = 'y')")) - assertEqual("DESC DATABASE foo", parser.parsePlan("DESC SCHEMA foo")) - } ->>>>>>> 9cc925cda2... [SPARK-27209][SQL] Split parsing of SELECT and INSERT into two top-level rules in the grammar file. } From b9a206102f74527339632729c5d33b6ab5daa9f8 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Mon, 25 Mar 2019 17:43:03 -0700 Subject: [PATCH 21/70] [SPARK-27209][SQL] Split parsing of SELECT and INSERT into two top-level rules in the grammar file. Currently in the grammar file the rule `query` is responsible to parse both select and insert statements. As a result, we need to have more semantic checks in the code to guard against in-valid insert constructs in a query. Couple of examples are in the `visitCreateView` and `visitAlterView` functions. One other issue is that, we don't catch the `invalid insert constructs` in all the places until checkAnalysis (the errors we raise can be confusing as well). 
Here are couple of examples : ```SQL select * from (insert into bar values (2)); ``` ``` Error in query: unresolved operator 'Project [*]; 'Project [*] +- SubqueryAlias `__auto_generated_subquery_name` +- InsertIntoHiveTable `default`.`bar`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, false, false, [c1] +- Project [cast(col1#18 as int) AS c1#20] +- LocalRelation [col1#18] ``` ```SQL select * from foo where c1 in (insert into bar values (2)) ``` ``` Error in query: cannot resolve '(default.foo.`c1` IN (listquery()))' due to data type mismatch: The number of columns in the left hand side of an IN subquery does not match the number of columns in the output of subquery. Left side columns: [default.foo.`c1`]. Right side columns: [].;; 'Project [*] +- 'Filter c1#6 IN (list#5 []) : +- InsertIntoHiveTable `default`.`bar`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, false, false, [c1] : +- Project [cast(col1#7 as int) AS c1#9] : +- LocalRelation [col1#7] +- SubqueryAlias `default`.`foo` +- HiveTableRelation `default`.`foo`, org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe, [c1#6] ``` For both the cases above, we should reject the syntax at parser level. In this PR, we create two top-level parser rules to parse `SELECT` and `INSERT` respectively. I will create a small PR to allow CTEs in DESCRIBE QUERY after this PR is in. Added tests to PlanParserSuite and removed the semantic check tests from SparkSqlParserSuites. Closes #24150 from dilipbiswal/split-query-insert. Authored-by: Dilip Biswal Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/parser/SqlBase.g4 | 23 ++++-- .../sql/catalyst/parser/AstBuilder.scala | 81 ++++++++++++++----- .../sql/catalyst/parser/PlanParserSuite.scala | 49 ++++++++++- .../spark/sql/execution/SparkSqlParser.scala | 17 ---- .../sql/execution/SparkSqlParserSuite.scala | 53 ++++++------ 5 files changed, 157 insertions(+), 66 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 76a13c5e2478f..e527d186210d6 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -76,6 +76,8 @@ singleTableSchema statement : query #statementDefault + | insertStatement #insertStatementDefault + | multiSelectStatement #multiSelectStatementDefault | USE db=identifier #use | CREATE DATABASE (IF NOT EXISTS)? identifier (COMMENT comment=STRING)? locationSpec? @@ -343,9 +345,14 @@ resource : identifier STRING ; +insertStatement + : (ctes)? insertInto queryTerm queryOrganization #singleInsertQuery + | (ctes)? fromClause multiInsertQueryBody+ #multiInsertQuery + ; + queryNoWith - : insertInto? queryTerm queryOrganization #singleInsertQuery - | fromClause multiInsertQueryBody+ #multiInsertQuery + : queryTerm queryOrganization #noWithQuery + | fromClause selectStatement #queryWithFrom ; queryOrganization @@ -358,9 +365,15 @@ queryOrganization ; multiInsertQueryBody - : insertInto? - querySpecification - queryOrganization + : insertInto selectStatement + ; + +multiSelectStatement + : (ctes)? 
fromClause selectStatement+ #multiSelect + ; + +selectStatement + : querySpecification queryOrganization ; queryTerm diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index aa6d8cf7e5ad0..1eba982426117 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -111,15 +111,34 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val query = plan(ctx.queryNoWith) // Apply CTEs - query.optional(ctx.ctes) { - val ctes = ctx.ctes.namedQuery.asScala.map { nCtx => - val namedQuery = visitNamedQuery(nCtx) - (namedQuery.alias, namedQuery) - } - // Check for duplicate names. - checkDuplicateKeys(ctes, ctx) - With(query, ctes) + query.optionalMap(ctx.ctes)(withCTE) + } + + private def withCTE(ctx: CtesContext, plan: LogicalPlan): LogicalPlan = { + val ctes = ctx.namedQuery.asScala.map { nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) } + // Check for duplicate names. + checkDuplicateKeys(ctes, ctx) + With(plan, ctes) + } + + override def visitQueryToDesc(ctx: QueryToDescContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses) + } + + override def visitQueryWithFrom(ctx: QueryWithFromContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + validate(ctx.selectStatement.querySpecification.fromClause == null, + "Individual select statement can not have FROM cause as its already specified in the" + + " outer query block", ctx) + withQuerySpecification(ctx.selectStatement.querySpecification, from). + optionalMap(ctx.selectStatement.queryOrganization)(withQueryResultClauses) + } + + override def visitNoWithQuery(ctx: NoWithQueryContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses) } /** @@ -151,24 +170,49 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val from = visitFromClause(ctx.fromClause) // Build the insert clauses. - val inserts = ctx.multiInsertQueryBody.asScala.map { + val inserts = ctx.multiInsertQueryBody().asScala.map { body => - validate(body.querySpecification.fromClause == null, + validate(body.selectStatement.querySpecification.fromClause == null, "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", body) + withInsertInto(body.insertInto, + withQuerySpecification(body.selectStatement.querySpecification, from). + // Add organization statements. + optionalMap(body.selectStatement.queryOrganization)(withQueryResultClauses)) + } + + // If there are multiple INSERTS just UNION them together into one query. + val insertPlan = inserts match { + case Seq(query) => query + case queries => Union(queries) + } + // Apply CTEs + insertPlan.optionalMap(ctx.ctes)(withCTE) + } + + override def visitMultiSelect(ctx: MultiSelectContext): LogicalPlan = withOrigin(ctx) { + val from = visitFromClause(ctx.fromClause) + + // Build the insert clauses. + val selects = ctx.selectStatement.asScala.map { + body => + validate(body.querySpecification.fromClause == null, + "Multi-select queries cannot have a FROM clause in their individual SELECT statements", + body) + withQuerySpecification(body.querySpecification, from). // Add organization statements. 
- optionalMap(body.queryOrganization)(withQueryResultClauses). - // Add insert. - optionalMap(body.insertInto())(withInsertInto) + optionalMap(body.queryOrganization)(withQueryResultClauses) } // If there are multiple INSERTS just UNION them together into one query. - inserts match { + val selectUnionPlan = selects match { case Seq(query) => query case queries => Union(queries) } + // Apply CTEs + selectUnionPlan.optionalMap(ctx.ctes)(withCTE) } /** @@ -176,11 +220,10 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging */ override def visitSingleInsertQuery( ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { - plan(ctx.queryTerm). - // Add organization statements. - optionalMap(ctx.queryOrganization)(withQueryResultClauses). - // Add insert. - optionalMap(ctx.insertInto())(withInsertInto) + val insertPlan = withInsertInto(ctx.insertInto(), + plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses)) + // Apply CTEs + insertPlan.optionalMap(ctx.ctes)(withCTE) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index f5da90f7cf0c6..5ef3b2b7615d4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -132,7 +132,11 @@ class PlanParserSuite extends AnalysisTest { table("a").select(star()).union(table("a").where('s < 10).select(star()))) intercept( "from a select * select * from x where a.s < 10", - "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements") + "Multi-select queries cannot have a FROM clause in their individual SELECT statements") + intercept( + "from a select * from b", + "Individual select statement can not have FROM cause as its already specified in " + + "the outer query block") assertEqual( "from a insert into tbl1 select * insert into tbl2 select * where s < 10", table("a").select(star()).insertInto("tbl1").union( @@ -753,4 +757,47 @@ class PlanParserSuite extends AnalysisTest { assertEqual(query2, Distinct(a.union(b)).except(c.intersect(d, isAll = true), isAll = true)) } } + + test("create/alter view as insert into table") { + val m1 = intercept[ParseException] { + parsePlan("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)") + }.getMessage + assert(m1.contains("mismatched input 'INSERT' expecting")) + // Multi insert query + val m2 = intercept[ParseException] { + parsePlan( + """ + |CREATE VIEW testView AS FROM jt + |INSERT INTO tbl1 SELECT * WHERE jt.id < 5 + |INSERT INTO tbl2 SELECT * WHERE jt.id > 4 + """.stripMargin) + }.getMessage + assert(m2.contains("mismatched input 'INSERT' expecting")) + val m3 = intercept[ParseException] { + parsePlan("ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)") + }.getMessage + assert(m3.contains("mismatched input 'INSERT' expecting")) + // Multi insert query + val m4 = intercept[ParseException] { + parsePlan( + """ + |ALTER VIEW testView AS FROM jt + |INSERT INTO tbl1 SELECT * WHERE jt.id < 5 + |INSERT INTO tbl2 SELECT * WHERE jt.id > 4 + """.stripMargin + ) + }.getMessage + assert(m4.contains("mismatched input 'INSERT' expecting")) + } + + test("Invalid insert constructs in the query") { + val m1 = intercept[ParseException] { + parsePlan("SELECT * FROM (INSERT INTO BAR VALUES (2))") + }.getMessage + assert(m1.contains("mismatched input 'FROM' expecting")) + 
val m2 = intercept[ParseException] { + parsePlan("SELECT * FROM S WHERE C1 IN (INSERT INTO T VALUES (2))") + }.getMessage + assert(m2.contains("mismatched input 'FROM' expecting")) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 8deb55b00a9d3..6757efd19b5a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -1431,15 +1431,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { if (ctx.identifierList != null) { operationNotAllowed("CREATE VIEW ... PARTITIONED ON", ctx) } else { - // CREATE VIEW ... AS INSERT INTO is not allowed. - ctx.query.queryNoWith match { - case s: SingleInsertQueryContext if s.insertInto != null => - operationNotAllowed("CREATE VIEW ... AS INSERT INTO", ctx) - case _: MultiInsertQueryContext => - operationNotAllowed("CREATE VIEW ... AS FROM ... [INSERT INTO ...]+", ctx) - case _ => // OK - } - val userSpecifiedColumns = Option(ctx.identifierCommentList).toSeq.flatMap { icl => icl.identifierComment.asScala.map { ic => ic.identifier.getText -> Option(ic.STRING).map(string) @@ -1476,14 +1467,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) { - // ALTER VIEW ... AS INSERT INTO is not allowed. - ctx.query.queryNoWith match { - case s: SingleInsertQueryContext if s.insertInto != null => - operationNotAllowed("ALTER VIEW ... AS INSERT INTO", ctx) - case _: MultiInsertQueryContext => - operationNotAllowed("ALTER VIEW ... AS FROM ... [INSERT INTO ...]+", ctx) - case _ => // OK - } AlterViewAsCommand( name = visitTableIdentifier(ctx.tableIdentifier), originalText = source(ctx.query), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 31b9bcdafbab8..a7a12cb6eebb4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -215,30 +215,6 @@ class SparkSqlParserSuite extends AnalysisTest { "no viable alternative at input") } - test("create table using - schema") { - assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet", - createTableUsing( - table = "my_tab", - schema = (new StructType) - .add("a", IntegerType, nullable = true, "test") - .add("b", StringType) - ) - ) - intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING) USING parquet", - "no viable alternative at input") - } - - test("create view as insert into table") { - // Single insert query - intercept("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)", - "Operation not allowed: CREATE VIEW ... AS INSERT INTO") - - // Multi insert query - intercept("CREATE VIEW testView AS FROM jt INSERT INTO tbl1 SELECT * WHERE jt.id < 5 " + - "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", - "Operation not allowed: CREATE VIEW ... AS FROM ... 
[INSERT INTO ...]+") - } - test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") { assertEqual("describe table t", DescribeTableCommand( @@ -377,6 +353,18 @@ class SparkSqlParserSuite extends AnalysisTest { Project(UnresolvedAlias(concat) :: Nil, UnresolvedRelation(TableIdentifier("t")))) } +<<<<<<< HEAD + test("SPARK-25046 Fix Alter View ... As Insert Into Table") { + // Single insert query + intercept("ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)", + "Operation not allowed: ALTER VIEW ... AS INSERT INTO") + + // Multi insert query + intercept("ALTER VIEW testView AS FROM jt INSERT INTO tbl1 SELECT * WHERE jt.id < 5 " + + "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", + "Operation not allowed: ALTER VIEW ... AS FROM ... [INSERT INTO ...]+") + } +||||||| parent of 9cc925cda2... [SPARK-27209][SQL] Split parsing of SELECT and INSERT into two top-level rules in the grammar file. test("SPARK-25046 Fix Alter View ... As Insert Into Table") { // Single insert query intercept("ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)", @@ -387,4 +375,21 @@ class SparkSqlParserSuite extends AnalysisTest { "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", "Operation not allowed: ALTER VIEW ... AS FROM ... [INSERT INTO ...]+") } + + test("database and schema tokens are interchangeable") { + assertEqual("CREATE DATABASE foo", parser.parsePlan("CREATE SCHEMA foo")) + assertEqual("DROP DATABASE foo", parser.parsePlan("DROP SCHEMA foo")) + assertEqual("ALTER DATABASE foo SET DBPROPERTIES ('x' = 'y')", + parser.parsePlan("ALTER SCHEMA foo SET DBPROPERTIES ('x' = 'y')")) + assertEqual("DESC DATABASE foo", parser.parsePlan("DESC SCHEMA foo")) + } +======= + test("database and schema tokens are interchangeable") { + assertEqual("CREATE DATABASE foo", parser.parsePlan("CREATE SCHEMA foo")) + assertEqual("DROP DATABASE foo", parser.parsePlan("DROP SCHEMA foo")) + assertEqual("ALTER DATABASE foo SET DBPROPERTIES ('x' = 'y')", + parser.parsePlan("ALTER SCHEMA foo SET DBPROPERTIES ('x' = 'y')")) + assertEqual("DESC DATABASE foo", parser.parsePlan("DESC SCHEMA foo")) + } +>>>>>>> 9cc925cda2... [SPARK-27209][SQL] Split parsing of SELECT and INSERT into two top-level rules in the grammar file. } From e68a36c9101e289e95d03e09a2969077918f9596 Mon Sep 17 00:00:00 2001 From: mcheah Date: Wed, 15 May 2019 14:49:59 -0700 Subject: [PATCH 22/70] Revert "[SPARK-27209][SQL] Split parsing of SELECT and INSERT into two top-level rules in the grammar file." This reverts commit b9a206102f74527339632729c5d33b6ab5daa9f8. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 23 ++---- .../sql/catalyst/parser/AstBuilder.scala | 81 +++++-------------- .../sql/catalyst/parser/PlanParserSuite.scala | 49 +---------- .../spark/sql/execution/SparkSqlParser.scala | 17 ++++ .../sql/execution/SparkSqlParserSuite.scala | 53 ++++++------ 5 files changed, 66 insertions(+), 157 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index e527d186210d6..76a13c5e2478f 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -76,8 +76,6 @@ singleTableSchema statement : query #statementDefault - | insertStatement #insertStatementDefault - | multiSelectStatement #multiSelectStatementDefault | USE db=identifier #use | CREATE DATABASE (IF NOT EXISTS)? identifier (COMMENT comment=STRING)? locationSpec? 
@@ -345,14 +343,9 @@ resource : identifier STRING ; -insertStatement - : (ctes)? insertInto queryTerm queryOrganization #singleInsertQuery - | (ctes)? fromClause multiInsertQueryBody+ #multiInsertQuery - ; - queryNoWith - : queryTerm queryOrganization #noWithQuery - | fromClause selectStatement #queryWithFrom + : insertInto? queryTerm queryOrganization #singleInsertQuery + | fromClause multiInsertQueryBody+ #multiInsertQuery ; queryOrganization @@ -365,15 +358,9 @@ queryOrganization ; multiInsertQueryBody - : insertInto selectStatement - ; - -multiSelectStatement - : (ctes)? fromClause selectStatement+ #multiSelect - ; - -selectStatement - : querySpecification queryOrganization + : insertInto? + querySpecification + queryOrganization ; queryTerm diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 1eba982426117..aa6d8cf7e5ad0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -111,34 +111,15 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val query = plan(ctx.queryNoWith) // Apply CTEs - query.optionalMap(ctx.ctes)(withCTE) - } - - private def withCTE(ctx: CtesContext, plan: LogicalPlan): LogicalPlan = { - val ctes = ctx.namedQuery.asScala.map { nCtx => - val namedQuery = visitNamedQuery(nCtx) - (namedQuery.alias, namedQuery) + query.optional(ctx.ctes) { + val ctes = ctx.ctes.namedQuery.asScala.map { nCtx => + val namedQuery = visitNamedQuery(nCtx) + (namedQuery.alias, namedQuery) + } + // Check for duplicate names. + checkDuplicateKeys(ctes, ctx) + With(query, ctes) } - // Check for duplicate names. - checkDuplicateKeys(ctes, ctx) - With(plan, ctes) - } - - override def visitQueryToDesc(ctx: QueryToDescContext): LogicalPlan = withOrigin(ctx) { - plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses) - } - - override def visitQueryWithFrom(ctx: QueryWithFromContext): LogicalPlan = withOrigin(ctx) { - val from = visitFromClause(ctx.fromClause) - validate(ctx.selectStatement.querySpecification.fromClause == null, - "Individual select statement can not have FROM cause as its already specified in the" + - " outer query block", ctx) - withQuerySpecification(ctx.selectStatement.querySpecification, from). - optionalMap(ctx.selectStatement.queryOrganization)(withQueryResultClauses) - } - - override def visitNoWithQuery(ctx: NoWithQueryContext): LogicalPlan = withOrigin(ctx) { - plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses) } /** @@ -170,49 +151,24 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val from = visitFromClause(ctx.fromClause) // Build the insert clauses. - val inserts = ctx.multiInsertQueryBody().asScala.map { - body => - validate(body.selectStatement.querySpecification.fromClause == null, - "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", - body) - - withInsertInto(body.insertInto, - withQuerySpecification(body.selectStatement.querySpecification, from). - // Add organization statements. - optionalMap(body.selectStatement.queryOrganization)(withQueryResultClauses)) - } - - // If there are multiple INSERTS just UNION them together into one query. 
- val insertPlan = inserts match { - case Seq(query) => query - case queries => Union(queries) - } - // Apply CTEs - insertPlan.optionalMap(ctx.ctes)(withCTE) - } - - override def visitMultiSelect(ctx: MultiSelectContext): LogicalPlan = withOrigin(ctx) { - val from = visitFromClause(ctx.fromClause) - - // Build the insert clauses. - val selects = ctx.selectStatement.asScala.map { + val inserts = ctx.multiInsertQueryBody.asScala.map { body => validate(body.querySpecification.fromClause == null, - "Multi-select queries cannot have a FROM clause in their individual SELECT statements", + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements", body) withQuerySpecification(body.querySpecification, from). // Add organization statements. - optionalMap(body.queryOrganization)(withQueryResultClauses) + optionalMap(body.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(body.insertInto())(withInsertInto) } // If there are multiple INSERTS just UNION them together into one query. - val selectUnionPlan = selects match { + inserts match { case Seq(query) => query case queries => Union(queries) } - // Apply CTEs - selectUnionPlan.optionalMap(ctx.ctes)(withCTE) } /** @@ -220,10 +176,11 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging */ override def visitSingleInsertQuery( ctx: SingleInsertQueryContext): LogicalPlan = withOrigin(ctx) { - val insertPlan = withInsertInto(ctx.insertInto(), - plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses)) - // Apply CTEs - insertPlan.optionalMap(ctx.ctes)(withCTE) + plan(ctx.queryTerm). + // Add organization statements. + optionalMap(ctx.queryOrganization)(withQueryResultClauses). + // Add insert. + optionalMap(ctx.insertInto())(withInsertInto) } /** diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 5ef3b2b7615d4..f5da90f7cf0c6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -132,11 +132,7 @@ class PlanParserSuite extends AnalysisTest { table("a").select(star()).union(table("a").where('s < 10).select(star()))) intercept( "from a select * select * from x where a.s < 10", - "Multi-select queries cannot have a FROM clause in their individual SELECT statements") - intercept( - "from a select * from b", - "Individual select statement can not have FROM cause as its already specified in " + - "the outer query block") + "Multi-Insert queries cannot have a FROM clause in their individual SELECT statements") assertEqual( "from a insert into tbl1 select * insert into tbl2 select * where s < 10", table("a").select(star()).insertInto("tbl1").union( @@ -757,47 +753,4 @@ class PlanParserSuite extends AnalysisTest { assertEqual(query2, Distinct(a.union(b)).except(c.intersect(d, isAll = true), isAll = true)) } } - - test("create/alter view as insert into table") { - val m1 = intercept[ParseException] { - parsePlan("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)") - }.getMessage - assert(m1.contains("mismatched input 'INSERT' expecting")) - // Multi insert query - val m2 = intercept[ParseException] { - parsePlan( - """ - |CREATE VIEW testView AS FROM jt - |INSERT INTO tbl1 SELECT * WHERE jt.id < 5 - |INSERT INTO tbl2 SELECT * WHERE jt.id > 4 - """.stripMargin) - }.getMessage - 
assert(m2.contains("mismatched input 'INSERT' expecting")) - val m3 = intercept[ParseException] { - parsePlan("ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)") - }.getMessage - assert(m3.contains("mismatched input 'INSERT' expecting")) - // Multi insert query - val m4 = intercept[ParseException] { - parsePlan( - """ - |ALTER VIEW testView AS FROM jt - |INSERT INTO tbl1 SELECT * WHERE jt.id < 5 - |INSERT INTO tbl2 SELECT * WHERE jt.id > 4 - """.stripMargin - ) - }.getMessage - assert(m4.contains("mismatched input 'INSERT' expecting")) - } - - test("Invalid insert constructs in the query") { - val m1 = intercept[ParseException] { - parsePlan("SELECT * FROM (INSERT INTO BAR VALUES (2))") - }.getMessage - assert(m1.contains("mismatched input 'FROM' expecting")) - val m2 = intercept[ParseException] { - parsePlan("SELECT * FROM S WHERE C1 IN (INSERT INTO T VALUES (2))") - }.getMessage - assert(m2.contains("mismatched input 'FROM' expecting")) - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 6757efd19b5a9..8deb55b00a9d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -1431,6 +1431,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { if (ctx.identifierList != null) { operationNotAllowed("CREATE VIEW ... PARTITIONED ON", ctx) } else { + // CREATE VIEW ... AS INSERT INTO is not allowed. + ctx.query.queryNoWith match { + case s: SingleInsertQueryContext if s.insertInto != null => + operationNotAllowed("CREATE VIEW ... AS INSERT INTO", ctx) + case _: MultiInsertQueryContext => + operationNotAllowed("CREATE VIEW ... AS FROM ... [INSERT INTO ...]+", ctx) + case _ => // OK + } + val userSpecifiedColumns = Option(ctx.identifierCommentList).toSeq.flatMap { icl => icl.identifierComment.asScala.map { ic => ic.identifier.getText -> Option(ic.STRING).map(string) @@ -1467,6 +1476,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitAlterViewQuery(ctx: AlterViewQueryContext): LogicalPlan = withOrigin(ctx) { + // ALTER VIEW ... AS INSERT INTO is not allowed. + ctx.query.queryNoWith match { + case s: SingleInsertQueryContext if s.insertInto != null => + operationNotAllowed("ALTER VIEW ... AS INSERT INTO", ctx) + case _: MultiInsertQueryContext => + operationNotAllowed("ALTER VIEW ... AS FROM ... 
[INSERT INTO ...]+", ctx) + case _ => // OK + } AlterViewAsCommand( name = visitTableIdentifier(ctx.tableIdentifier), originalText = source(ctx.query), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index a7a12cb6eebb4..31b9bcdafbab8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -215,6 +215,30 @@ class SparkSqlParserSuite extends AnalysisTest { "no viable alternative at input") } + test("create table using - schema") { + assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet", + createTableUsing( + table = "my_tab", + schema = (new StructType) + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType) + ) + ) + intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING) USING parquet", + "no viable alternative at input") + } + + test("create view as insert into table") { + // Single insert query + intercept("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)", + "Operation not allowed: CREATE VIEW ... AS INSERT INTO") + + // Multi insert query + intercept("CREATE VIEW testView AS FROM jt INSERT INTO tbl1 SELECT * WHERE jt.id < 5 " + + "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", + "Operation not allowed: CREATE VIEW ... AS FROM ... [INSERT INTO ...]+") + } + test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") { assertEqual("describe table t", DescribeTableCommand( @@ -353,18 +377,6 @@ class SparkSqlParserSuite extends AnalysisTest { Project(UnresolvedAlias(concat) :: Nil, UnresolvedRelation(TableIdentifier("t")))) } -<<<<<<< HEAD - test("SPARK-25046 Fix Alter View ... As Insert Into Table") { - // Single insert query - intercept("ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)", - "Operation not allowed: ALTER VIEW ... AS INSERT INTO") - - // Multi insert query - intercept("ALTER VIEW testView AS FROM jt INSERT INTO tbl1 SELECT * WHERE jt.id < 5 " + - "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", - "Operation not allowed: ALTER VIEW ... AS FROM ... [INSERT INTO ...]+") - } -||||||| parent of 9cc925cda2... [SPARK-27209][SQL] Split parsing of SELECT and INSERT into two top-level rules in the grammar file. test("SPARK-25046 Fix Alter View ... As Insert Into Table") { // Single insert query intercept("ALTER VIEW testView AS INSERT INTO jt VALUES(1, 1)", @@ -375,21 +387,4 @@ class SparkSqlParserSuite extends AnalysisTest { "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", "Operation not allowed: ALTER VIEW ... AS FROM ... 
[INSERT INTO ...]+") } - - test("database and schema tokens are interchangeable") { - assertEqual("CREATE DATABASE foo", parser.parsePlan("CREATE SCHEMA foo")) - assertEqual("DROP DATABASE foo", parser.parsePlan("DROP SCHEMA foo")) - assertEqual("ALTER DATABASE foo SET DBPROPERTIES ('x' = 'y')", - parser.parsePlan("ALTER SCHEMA foo SET DBPROPERTIES ('x' = 'y')")) - assertEqual("DESC DATABASE foo", parser.parsePlan("DESC SCHEMA foo")) - } -======= - test("database and schema tokens are interchangeable") { - assertEqual("CREATE DATABASE foo", parser.parsePlan("CREATE SCHEMA foo")) - assertEqual("DROP DATABASE foo", parser.parsePlan("DROP SCHEMA foo")) - assertEqual("ALTER DATABASE foo SET DBPROPERTIES ('x' = 'y')", - parser.parsePlan("ALTER SCHEMA foo SET DBPROPERTIES ('x' = 'y')")) - assertEqual("DESC DATABASE foo", parser.parsePlan("DESC SCHEMA foo")) - } ->>>>>>> 9cc925cda2... [SPARK-27209][SQL] Split parsing of SELECT and INSERT into two top-level rules in the grammar file. } From 942ac185c9222211fd1c6cd1ddad23b9e05ac266 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 23 Feb 2019 08:38:47 +0900 Subject: [PATCH 23/70] [SPARK-26215][SQL] Define reserved/non-reserved keywords based on the ANSI SQL standard ## What changes were proposed in this pull request? This pr targeted to define reserved/non-reserved keywords for Spark SQL based on the ANSI SQL standards and the other database-like systems (e.g., PostgreSQL). We assume that they basically follow the ANSI SQL-2011 standard, but it is slightly different between each other. Therefore, this pr documented all the keywords in `docs/sql-reserved-and-non-reserved-key-words.md`. NOTE: This pr only added a small set of keywords as reserved ones and these keywords are reserved in all the ANSI SQL standards (SQL-92, SQL-99, SQL-2003, SQL-2008, SQL-2011, and SQL-2016) and PostgreSQL. This is because there is room to discuss which keyword should be reserved or not, .e.g., interval units (day, hour, minute, second, ...) are reserved in the ANSI SQL standards though, they are not reserved in PostgreSQL. Therefore, we need more researches about the other database-like systems (e.g., Oracle Databases, DB2, SQL server) in follow-up activities. References: - The reserved/non-reserved SQL keywords in the ANSI SQL standards: https://developer.mimer.com/wp-content/uploads/2018/05/Standard-SQL-Reserved-Words-Summary.pdf - SQL Key Words in PostgreSQL: https://www.postgresql.org/docs/current/sql-keywords-appendix.html ## How was this patch tested? Added tests in `TableIdentifierParserSuite`. Closes #23259 from maropu/SPARK-26215-WIP. 
Authored-by: Takeshi Yamamuro Signed-off-by: Takeshi Yamamuro --- docs/_data/menu-sql.yaml | 2 + .../sql-reserved-and-non-reserved-keywords.md | 574 ++++++++++++++++++ .../spark/sql/catalyst/parser/SqlBase.g4 | 127 ++-- .../sql/catalyst/parser/ParseDriver.scala | 2 + .../apache/spark/sql/internal/SQLConf.scala | 8 + 5 files changed, 675 insertions(+), 38 deletions(-) create mode 100644 docs/sql-reserved-and-non-reserved-keywords.md diff --git a/docs/_data/menu-sql.yaml b/docs/_data/menu-sql.yaml index cd065ea01dda4..9bbb115bcdda5 100644 --- a/docs/_data/menu-sql.yaml +++ b/docs/_data/menu-sql.yaml @@ -70,6 +70,8 @@ url: sql-migration-guide-upgrade.html - text: Compatibility with Apache Hive url: sql-migration-guide-hive-compatibility.html + - text: SQL Reserved/Non-Reserved Keywords + url: sql-reserved-and-non-reserved-keywords.html - text: Reference url: sql-reference.html subitems: diff --git a/docs/sql-reserved-and-non-reserved-keywords.md b/docs/sql-reserved-and-non-reserved-keywords.md new file mode 100644 index 0000000000000..321fb3f00acbd --- /dev/null +++ b/docs/sql-reserved-and-non-reserved-keywords.md @@ -0,0 +1,574 @@ +--- +layout: global +title: SQL Reserved/Non-Reserved Keywords +displayTitle: SQL Reserved/Non-Reserved Keywords +--- + +In Spark SQL, there are 2 kinds of keywords: non-reserved and reserved. Non-reserved keywords have a +special meaning only in particular contexts and can be used as identifiers (e.g., table names, view names, +column names, column aliases, table aliases) in other contexts. Reserved keywords can't be used as +table alias, but can be used as other identifiers. + +The list of reserved and non-reserved keywords can change according to the config +`spark.sql.parser.ansi.enabled`, which is false by default.
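For a concrete sense of what the flag changes, take the `ALL` entry in the table below: it is non-reserved in the default mode but reserved in ANSI mode. The following sketch (hypothetical table name `t`; the comments describe the expected behavior, not error text taken from this patch) would therefore accept `all` as a table alias only while the flag is off:

```sql
-- Default mode: ALL is non-reserved, so it is a legal table alias.
SET spark.sql.parser.ansi.enabled=false;
SELECT * FROM t AS all;        -- parses

-- ANSI mode: ALL becomes reserved and can no longer be used as a table alias.
SET spark.sql.parser.ansi.enabled=true;
SELECT * FROM t AS all;        -- expected to fail with a ParseException
```

Keywords that the table marks as reserved in both modes (for example `ANTI`, `CROSS`, `FULL`, `MINUS`) are rejected as table aliases regardless of the flag.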
| Keyword | Spark SQL (ANSI mode) | Spark SQL (default mode) | SQL-2011 |
|---|---|---|---|
| ABS | non-reserved | non-reserved | reserved |
| ABSOLUTE | non-reserved | non-reserved | non-reserved |
| ACOS | non-reserved | non-reserved | non-reserved |
| ACTION | non-reserved | non-reserved | non-reserved |
| ADD | non-reserved | non-reserved | non-reserved |
| AFTER | non-reserved | non-reserved | non-reserved |
| ALL | reserved | non-reserved | reserved |
| ALLOCATE | non-reserved | non-reserved | reserved |
| ALTER | non-reserved | non-reserved | reserved |
| ANALYZE | non-reserved | non-reserved | non-reserved |
| AND | reserved | non-reserved | reserved |
| ANTI | reserved | reserved | non-reserved |
| ANY | reserved | non-reserved | reserved |
| ARE | non-reserved | non-reserved | reserved |
| ARCHIVE | non-reserved | non-reserved | non-reserved |
| ARRAY | non-reserved | non-reserved | reserved |
| ARRAY_AGG | non-reserved | non-reserved | reserved |
| ARRAY_MAX_CARDINALITY | non-reserved | non-reserved | reserved |
| AS | reserved | non-reserved | reserved |
| ASC | non-reserved | non-reserved | non-reserved |
| ASENSITIVE | non-reserved | non-reserved | reserved |
| ASIN | non-reserved | non-reserved | reserved |
| ASSERTION | non-reserved | non-reserved | non-reserved |
| ASYMMETRIC | non-reserved | non-reserved | reserved |
| AT | non-reserved | non-reserved | reserved |
| ATAN | non-reserved | non-reserved | non-reserved |
| ATOMIC | non-reserved | non-reserved | reserved |
| AUTHORIZATION | reserved | non-reserved | reserved |
| AVG | non-reserved | non-reserved | reserved |
| BEFORE | non-reserved | non-reserved | non-reserved |
| BEGIN | non-reserved | non-reserved | reserved |
| BEGIN_FRAME | non-reserved | non-reserved | reserved |
| BEGIN_PARTITION | non-reserved | non-reserved | reserved |
| BETWEEN | non-reserved | non-reserved | reserved |
| BIGINT | non-reserved | non-reserved | reserved |
| BINARY | non-reserved | non-reserved | reserved |
| BIT | non-reserved | non-reserved | non-reserved |
| BIT_LENGTH | non-reserved | non-reserved | non-reserved |
| BLOB | non-reserved | non-reserved | reserved |
| BOOLEAN | non-reserved | non-reserved | reserved |
| BOTH | reserved | non-reserved | reserved |
| BREADTH | non-reserved | non-reserved | non-reserved |
| BUCKET | non-reserved | non-reserved | non-reserved |
| BUCKETS | non-reserved | non-reserved | non-reserved |
| BY | non-reserved | non-reserved | reserved |
| CACHE | non-reserved | non-reserved | non-reserved |
| CALL | non-reserved | non-reserved | reserved |
| CALLED | non-reserved | non-reserved | reserved |
| CARDINALITY | non-reserved | non-reserved | reserved |
| CASCADE | non-reserved | non-reserved | reserved |
| CASCADED | non-reserved | non-reserved | reserved |
| CASE | reserved | non-reserved | reserved |
| CAST | reserved | non-reserved | reserved |
| CATALOG | non-reserved | non-reserved | non-reserved |
| CEIL | non-reserved | non-reserved | reserved |
| CEILING | non-reserved | non-reserved | reserved |
| CHANGE | non-reserved | non-reserved | non-reserved |
| CHAR | non-reserved | non-reserved | reserved |
| CHAR_LENGTH | non-reserved | non-reserved | reserved |
| CHARACTER | non-reserved | non-reserved | reserved |
| CHARACTER_LENGTH | non-reserved | non-reserved | reserved |
| CHECK | reserved | non-reserved | reserved |
| CLASSIFIER | non-reserved | non-reserved | non-reserved |
| CLEAR | non-reserved | non-reserved | non-reserved |
| CLOB | non-reserved | non-reserved | reserved |
| CLOSE | non-reserved | non-reserved | reserved |
| CLUSTER | non-reserved | non-reserved | non-reserved |
| CLUSTERED | non-reserved | non-reserved | non-reserved |
| COALESCE | non-reserved | non-reserved | reserved |
| CODEGEN | non-reserved | non-reserved | non-reserved |
| COLLATE | reserved | non-reserved | reserved |
| COLLATION | non-reserved | non-reserved | non-reserved |
| COLLECT | non-reserved | non-reserved | reserved |
| COLLECTION | non-reserved | non-reserved | non-reserved |
| COLUMN | reserved | non-reserved | reserved |
| COLUMNS | non-reserved | non-reserved | non-reserved |
| COMMENT | non-reserved | non-reserved | non-reserved |
| COMMIT | non-reserved | non-reserved | reserved |
| COMPACT | non-reserved | non-reserved | non-reserved |
| COMPACTIONS | non-reserved | non-reserved | non-reserved |
| COMPUTE | non-reserved | non-reserved | non-reserved |
| CONCATENATE | non-reserved | non-reserved | non-reserved |
| CONDITION | non-reserved | non-reserved | reserved |
| CONNECT | non-reserved | non-reserved | non-reserved |
| CONNECTION | non-reserved | non-reserved | non-reserved |
| CONSTRAINT | reserved | non-reserved | reserved |
| CONSTRAINTS | non-reserved | non-reserved | non-reserved |
| CONSTRUCTOR | non-reserved | non-reserved | non-reserved |
| CONTAINS | non-reserved | non-reserved | non-reserved |
| CONTINUE | non-reserved | non-reserved | non-reserved |
| CONVERT | non-reserved | non-reserved | reserved |
| COPY | non-reserved | non-reserved | non-reserved |
| CORR | non-reserved | non-reserved | reserved |
| CORRESPONDING | non-reserved | non-reserved | reserved |
| COS | non-reserved | non-reserved | non-reserved |
| COSH | non-reserved | non-reserved | non-reserved |
| COST | non-reserved | non-reserved | non-reserved |
| COUNT | non-reserved | non-reserved | reserved |
| COVAR_POP | non-reserved | non-reserved | reserved |
| COVAR_SAMP | non-reserved | non-reserved | reserved |
| CREATE | reserved | non-reserved | reserved |
| CROSS | reserved | reserved | reserved |
| CUBE | non-reserved | non-reserved | reserved |
| CUME_DIST | non-reserved | non-reserved | reserved |
| CURRENT | non-reserved | non-reserved | reserved |
| CURRENT_CATALOG | non-reserved | non-reserved | reserved |
| CURRENT_DATE | reserved | non-reserved | reserved |
| CURRENT_DEFAULT_TRANSFORM_GROUP | non-reserved | non-reserved | reserved |
| CURRENT_PATH | non-reserved | non-reserved | reserved |
| CURRENT_ROLE | non-reserved | non-reserved | reserved |
| CURRENT_ROW | non-reserved | non-reserved | reserved |
| CURRENT_SCHEMA | non-reserved | non-reserved | reserved |
| CURRENT_TIME | reserved | non-reserved | reserved |
| CURRENT_TIMESTAMP | reserved | non-reserved | reserved |
| CURRENT_TRANSFORM_GROUP_FOR_TYPE | non-reserved | non-reserved | reserved |
| CURRENT_USER | reserved | non-reserved | reserved |
| CURSOR | non-reserved | non-reserved | reserved |
| CYCLE | non-reserved | non-reserved | reserved |
| DATA | non-reserved | non-reserved | non-reserved |
| DATABASE | non-reserved | non-reserved | non-reserved |
| DATABASES | non-reserved | non-reserved | non-reserved |
| DATE | non-reserved | non-reserved | reserved |
| DAY | non-reserved | non-reserved | reserved |
| DBPROPERTIES | non-reserved | non-reserved | non-reserved |
| DEALLOCATE | non-reserved | non-reserved | reserved |
| DEC | non-reserved | non-reserved | reserved |
| DECFLOAT | non-reserved | non-reserved | non-reserved |
| DECIMAL | non-reserved | non-reserved | reserved |
| DECLARE | non-reserved | non-reserved | reserved |
| DEFAULT | non-reserved | non-reserved | reserved |
| DEFERRABLE | non-reserved | non-reserved | non-reserved |
| DEFERRED | non-reserved | non-reserved | non-reserved |
| DEFINE | non-reserved | non-reserved | non-reserved |
| DEFINED | non-reserved | non-reserved | non-reserved |
| DELETE | non-reserved | non-reserved | reserved |
| DELIMITED | non-reserved | non-reserved | non-reserved |
| DENSE_RANK | non-reserved | non-reserved | reserved |
| DEPTH | non-reserved | non-reserved | non-reserved |
| DEREF | non-reserved | non-reserved | reserved |
| DESC | non-reserved | non-reserved | non-reserved |
| DESCRIBE | non-reserved | non-reserved | reserved |
| DESCRIPTOR | non-reserved | non-reserved | non-reserved |
| DETERMINISTIC | non-reserved | non-reserved | reserved |
| DFS | non-reserved | non-reserved | non-reserved |
| DIAGNOSTICS | non-reserved | non-reserved | non-reserved |
| DIRECTORIES | non-reserved | non-reserved | non-reserved |
| DIRECTORY | non-reserved | non-reserved | non-reserved |
| DISCONNECT | non-reserved | non-reserved | reserved |
| DISTINCT | reserved | non-reserved | reserved |
| DISTRIBUTE | non-reserved | non-reserved | non-reserved |
| DIV | non-reserved | non-reserved | non-reserved |
| DO | non-reserved | non-reserved | reserved |
| DOMAIN | non-reserved | non-reserved | non-reserved |
| DOUBLE | non-reserved | non-reserved | reserved |
| DROP | non-reserved | non-reserved | reserved |
| DYNAMIC | non-reserved | non-reserved | reserved |
| EACH | non-reserved | non-reserved | reserved |
| ELEMENT | non-reserved | non-reserved | reserved |
| ELSE | reserved | non-reserved | reserved |
| ELSEIF | non-reserved | non-reserved | reserved |
| EMPTY | non-reserved | non-reserved | non-reserved |
| END | reserved | non-reserved | reserved |
| END_FRAME | non-reserved | non-reserved | reserved |
| END_PARTITION | non-reserved | non-reserved | reserved |
| EQUALS | non-reserved | non-reserved | non-reserved |
| ESCAPE | non-reserved | non-reserved | reserved |
| ESCAPED | non-reserved | non-reserved | non-reserved |
| EVERY | non-reserved | non-reserved | reserved |
| EXCEPT | reserved | reserved | reserved |
| EXCEPTION | non-reserved | non-reserved | non-reserved |
| EXCHANGE | non-reserved | non-reserved | non-reserved |
| EXEC | non-reserved | non-reserved | reserved |
| EXECUTE | non-reserved | non-reserved | reserved |
| EXISTS | non-reserved | non-reserved | reserved |
| EXIT | non-reserved | non-reserved | non-reserved |
| EXP | non-reserved | non-reserved | non-reserved |
| EXPLAIN | non-reserved | non-reserved | non-reserved |
| EXPORT | non-reserved | non-reserved | non-reserved |
| EXTENDED | non-reserved | non-reserved | non-reserved |
| EXTERNAL | non-reserved | non-reserved | reserved |
| EXTRACT | non-reserved | non-reserved | reserved |
| FALSE | reserved | non-reserved | reserved |
| FETCH | reserved | non-reserved | reserved |
| FIELDS | non-reserved | non-reserved | non-reserved |
| FILEFORMAT | non-reserved | non-reserved | non-reserved |
| FILTER | non-reserved | non-reserved | reserved |
| FIRST | non-reserved | non-reserved | non-reserved |
| FIRST_VALUE | non-reserved | non-reserved | reserved |
| FLOAT | non-reserved | non-reserved | reserved |
| FOLLOWING | non-reserved | non-reserved | non-reserved |
| FOR | reserved | non-reserved | reserved |
| FOREIGN | reserved | non-reserved | reserved |
| FORMAT | non-reserved | non-reserved | non-reserved |
| FORMATTED | non-reserved | non-reserved | non-reserved |
| FOUND | non-reserved | non-reserved | non-reserved |
| FRAME_ROW | non-reserved | non-reserved | reserved |
| FREE | non-reserved | non-reserved | reserved |
| FROM | reserved | non-reserved | reserved |
| FULL | reserved | reserved | reserved |
| FUNCTION | non-reserved | non-reserved | reserved |
| FUNCTIONS | non-reserved | non-reserved | non-reserved |
| FUSION | non-reserved | non-reserved | non-reserved |
| GENERAL | non-reserved | non-reserved | non-reserved |
| GET | non-reserved | non-reserved | reserved |
| GLOBAL | non-reserved | non-reserved | reserved |
| GO | non-reserved | non-reserved | non-reserved |
| GOTO | non-reserved | non-reserved | non-reserved |
| GRANT | reserved | non-reserved | reserved |
| GROUP | reserved | non-reserved | reserved |
| GROUPING | non-reserved | non-reserved | reserved |
| GROUPS | non-reserved | non-reserved | reserved |
| HANDLER | non-reserved | non-reserved | reserved |
| HAVING | reserved | non-reserved | reserved |
| HOLD | non-reserved | non-reserved | reserved |
| HOUR | non-reserved | non-reserved | reserved |
| IDENTITY | non-reserved | non-reserved | reserved |
| IF | non-reserved | non-reserved | reserved |
| IGNORE | non-reserved | non-reserved | non-reserved |
| IMMEDIATE | non-reserved | non-reserved | non-reserved |
| IMPORT | non-reserved | non-reserved | non-reserved |
| IN | reserved | non-reserved | reserved |
| INDICATOR | non-reserved | non-reserved | reserved |
| INDEX | non-reserved | non-reserved | non-reserved |
| INDEXES | non-reserved | non-reserved | non-reserved |
| INITIAL | non-reserved | non-reserved | non-reserved |
| INITIALLY | non-reserved | non-reserved | non-reserved |
| INNER | reserved | reserved | reserved |
| INOUT | non-reserved | non-reserved | reserved |
| INPATH | non-reserved | non-reserved | non-reserved |
| INPUT | non-reserved | non-reserved | non-reserved |
| INPUTFORMAT | non-reserved | non-reserved | non-reserved |
| INSENSITIVE | non-reserved | non-reserved | reserved |
| INSERT | non-reserved | non-reserved | reserved |
| INT | non-reserved | non-reserved | reserved |
| INTEGER | non-reserved | non-reserved | reserved |
| INTERSECT | reserved | reserved | reserved |
| INTERSECTION | non-reserved | non-reserved | reserved |
| INTERVAL | non-reserved | non-reserved | reserved |
| INTO | reserved | non-reserved | reserved |
| IS | reserved | non-reserved | reserved |
| ISOLATION | non-reserved | non-reserved | non-reserved |
| ITEMS | non-reserved | non-reserved | non-reserved |
| ITERATE | non-reserved | non-reserved | reserved |
| JOIN | reserved | reserved | reserved |
| JSON_ARRAY | non-reserved | non-reserved | non-reserved |
| JSON_ARRAYAGG | non-reserved | non-reserved | non-reserved |
| JSON_EXISTS | non-reserved | non-reserved | non-reserved |
| JSON_OBJECT | non-reserved | non-reserved | non-reserved |
| JSON_OBJECTAGG | non-reserved | non-reserved | non-reserved |
| JSON_QUERY | non-reserved | non-reserved | non-reserved |
| JSON_TABLE | non-reserved | non-reserved | non-reserved |
| JSON_TABLE_PRIMITIVE | non-reserved | non-reserved | non-reserved |
| JSON_VALUE | non-reserved | non-reserved | non-reserved |
| KEY | non-reserved | non-reserved | non-reserved |
| KEYS | non-reserved | non-reserved | non-reserved |
| LAG | non-reserved | non-reserved | non-reserved |
| LANGUAGE | non-reserved | non-reserved | reserved |
| LARGE | non-reserved | non-reserved | reserved |
| LAST | non-reserved | non-reserved | non-reserved |
| LAST_VALUE | non-reserved | non-reserved | reserved |
| LATERAL | non-reserved | non-reserved | reserved |
| LAZY | non-reserved | non-reserved | non-reserved |
| LEAD | non-reserved | non-reserved | reserved |
| LEADING | reserved | non-reserved | reserved |
| LEAVE | non-reserved | non-reserved | reserved |
| LEFT | reserved | reserved | reserved |
| LEVEL | non-reserved | non-reserved | non-reserved |
| LIKE | non-reserved | non-reserved | reserved |
| LIKE_REGEX | non-reserved | non-reserved | reserved |
| LIMIT | non-reserved | non-reserved | non-reserved |
| LINES | non-reserved | non-reserved | non-reserved |
| LIST | non-reserved | non-reserved | non-reserved |
| LISTAGG | non-reserved | non-reserved | non-reserved |
| LN | non-reserved | non-reserved | reserved |
| LOAD | non-reserved | non-reserved | non-reserved |
| LOCAL | non-reserved | non-reserved | reserved |
| LOCALTIME | non-reserved | non-reserved | reserved |
| LOCALTIMESTAMP | non-reserved | non-reserved | reserved |
| LOCATION | non-reserved | non-reserved | non-reserved |
| LOCATOR | non-reserved | non-reserved | non-reserved |
| LOCK | non-reserved | non-reserved | non-reserved |
| LOCKS | non-reserved | non-reserved | non-reserved |
| LOG | non-reserved | non-reserved | non-reserved |
| LOG10 | non-reserved | non-reserved | non-reserved |
| LOGICAL | non-reserved | non-reserved | non-reserved |
| LOOP | non-reserved | non-reserved | reserved |
| LOWER | non-reserved | non-reserved | reserved |
| MACRO | non-reserved | non-reserved | non-reserved |
| MAP | non-reserved | non-reserved | non-reserved |
| MATCH | non-reserved | non-reserved | reserved |
| MATCH_NUMBER | non-reserved | non-reserved | non-reserved |
| MATCH_RECOGNIZE | non-reserved | non-reserved | non-reserved |
| MATCHES | non-reserved | non-reserved | non-reserved |
| MAX | non-reserved | non-reserved | reserved |
| MEMBER | non-reserved | non-reserved | reserved |
| MERGE | non-reserved | non-reserved | reserved |
| METHOD | non-reserved | non-reserved | reserved |
| MIN | non-reserved | non-reserved | reserved |
| MINUS | reserved | reserved | non-reserved |
| MINUTE | non-reserved | non-reserved | reserved |
| MOD | non-reserved | non-reserved | reserved |
| MODIFIES | non-reserved | non-reserved | reserved |
| MODULE | non-reserved | non-reserved | reserved |
MONTHnon-reservednon-reservedreserved
MSCKnon-reservednon-reservednon-reserved
MULTISETnon-reservednon-reservedreserved
NAMESnon-reservednon-reservednon-reserved
NATIONALnon-reservednon-reservedreserved
NATURALreservedreservedreserved
NCHARnon-reservednon-reservedreserved
NCLOBnon-reservednon-reservedreserved
NEWnon-reservednon-reservedreserved
NEXTnon-reservednon-reservednon-reserved
NOnon-reservednon-reservedreserved
NONEnon-reservednon-reservedreserved
NORMALIZEnon-reservednon-reservedreserved
NOTreservednon-reservedreserved
NTH_VALUEnon-reservednon-reservedreserved
NTILEnon-reservednon-reservedreserved
NULLreservednon-reservedreserved
NULLSnon-reservednon-reservednon-reserved
NULLIFnon-reservednon-reservedreserved
NUMERICnon-reservednon-reservedreserved
OBJECTnon-reservednon-reservednon-reserved
OCCURRENCES_REGEXnon-reservednon-reservednon-reserved
OCTET_LENGTHnon-reservednon-reservedreserved
OFnon-reservednon-reservedreserved
OFFSETnon-reservednon-reservedreserved
OLDnon-reservednon-reservedreserved
OMITnon-reservednon-reservednon-reserved
ONreservedreservedreserved
ONEnon-reservednon-reservednon-reserved
ONLYreservednon-reservedreserved
OPENnon-reservednon-reservedreserved
OPTIONnon-reservednon-reservednon-reserved
OPTIONSnon-reservednon-reservednon-reserved
ORreservednon-reservedreserved
ORDERreservednon-reservedreserved
ORDINALITYnon-reservednon-reservednon-reserved
OUTnon-reservednon-reservedreserved
OUTERreservednon-reservedreserved
OUTPUTnon-reservednon-reservednon-reserved
OUTPUTFORMATnon-reservednon-reservednon-reserved
OVERnon-reservednon-reservednon-reserved
OVERLAPSreservednon-reservedreserved
OVERLAYnon-reservednon-reservedreserved
OVERWRITEnon-reservednon-reservednon-reserved
PADnon-reservednon-reservednon-reserved
PARAMETERnon-reservednon-reservedreserved
PARTIALnon-reservednon-reservednon-reserved
PARTITIONnon-reservednon-reservedreserved
PARTITIONEDnon-reservednon-reservednon-reserved
PARTITIONSnon-reservednon-reservednon-reserved
PATHnon-reservednon-reservednon-reserved
PATTERNnon-reservednon-reservednon-reserved
PERnon-reservednon-reservednon-reserved
PERCENTnon-reservednon-reservedreserved
PERCENT_RANKnon-reservednon-reservedreserved
PERCENTILE_CONTnon-reservednon-reservedreserved
PERCENTILE_DISCnon-reservednon-reservedreserved
PERCENTLITnon-reservednon-reservednon-reserved
PERIODnon-reservednon-reservedreserved
PIVOTnon-reservednon-reservednon-reserved
PORTIONnon-reservednon-reservedreserved
POSITIONnon-reservednon-reservedreserved
POSITION_REGEXnon-reservednon-reservedreserved
POWERnon-reservednon-reservedreserved
PRECEDESnon-reservednon-reservedreserved
PRECEDINGnon-reservednon-reservednon-reserved
PRECISIONnon-reservednon-reservedreserved
PREPAREnon-reservednon-reservedreserved
PRESERVEnon-reservednon-reservednon-reserved
PRIMARYreservednon-reservedreserved
PRINCIPALSnon-reservednon-reservednon-reserved
PRIORnon-reservednon-reservednon-reserved
PRIVILEGESnon-reservednon-reservednon-reserved
PROCEDUREnon-reservednon-reservedreserved
PTFnon-reservednon-reservednon-reserved
PUBLICnon-reservednon-reservednon-reserved
PURGEnon-reservednon-reservednon-reserved
RANGEnon-reservednon-reservedreserved
RANKnon-reservednon-reservedreserved
READnon-reservednon-reservednon-reserved
READSnon-reservednon-reservedreserved
REALnon-reservednon-reservedreserved
RECORDREADERnon-reservednon-reservednon-reserved
RECORDWRITERnon-reservednon-reservednon-reserved
RECURSIVEnon-reservednon-reservedreserved
RECOVERnon-reservednon-reservednon-reserved
REDUCEnon-reservednon-reservednon-reserved
REFnon-reservednon-reservedreserved
REFERENCESreservednon-reservedreserved
REFERENCINGnon-reservednon-reservedreserved
REFRESHnon-reservednon-reservednon-reserved
REGR_AVGXnon-reservednon-reservedreserved
REGR_AVGYnon-reservednon-reservedreserved
REGR_COUNTnon-reservednon-reservedreserved
REGR_INTERCEPTnon-reservednon-reservedreserved
REGR_R2non-reservednon-reservedreserved
REGR_SLOPEnon-reservednon-reservedreserved
REGR_SXXnon-reservednon-reservedreserved
REGR_SXYnon-reservednon-reservedreserved
REGR_SYYnon-reservednon-reservedreserved
RELATIVEnon-reservednon-reservednon-reserved
RELEASEnon-reservednon-reservedreserved
RENAMEnon-reservednon-reservednon-reserved
REPAIRnon-reservednon-reservednon-reserved
REPEATnon-reservednon-reservedreserved
REPLACEnon-reservednon-reservednon-reserved
RESETnon-reservednon-reservednon-reserved
RESIGNALnon-reservednon-reservedreserved
RESTRICTnon-reservednon-reservednon-reserved
RESULTnon-reservednon-reservedreserved
RETURNnon-reservednon-reservedreserved
RETURNSnon-reservednon-reservedreserved
REVOKEnon-reservednon-reservedreserved
RIGHTreservedreservedreserved
RLIKEnon-reservednon-reservednon-reserved
ROLEnon-reservednon-reservednon-reserved
ROLESnon-reservednon-reservednon-reserved
ROLLBACKnon-reservednon-reservedreserved
ROLLUPnon-reservednon-reservedreserved
ROUTINEnon-reservednon-reservednon-reserved
ROWnon-reservednon-reservedreserved
ROW_NUMBERnon-reservednon-reservedreserved
ROWSnon-reservednon-reservedreserved
RUNNINGnon-reservednon-reservednon-reserved
SAVEPOINTnon-reservednon-reservedreserved
SCHEMAnon-reservednon-reservednon-reserved
SCOPEnon-reservednon-reservedreserved
SCROLLnon-reservednon-reservedreserved
SEARCHnon-reservednon-reservedreserved
SECONDnon-reservednon-reservedreserved
SECTIONnon-reservednon-reservednon-reserved
SEEKnon-reservednon-reservednon-reserved
SELECTreservednon-reservedreserved
SEMIreservedreservednon-reserved
SENSITIVEnon-reservednon-reservedreserved
SEPARATEDnon-reservednon-reservednon-reserved
SERDEnon-reservednon-reservednon-reserved
SERDEPROPERTIESnon-reservednon-reservednon-reserved
SESSIONnon-reservednon-reservednon-reserved
SESSION_USERreservednon-reservedreserved
SETnon-reservednon-reservedreserved
SETSnon-reservednon-reservednon-reserved
SHOWnon-reservednon-reservednon-reserved
SIGNALnon-reservednon-reservedreserved
SIMILARnon-reservednon-reservedreserved
SINnon-reservednon-reservednon-reserved
SINHnon-reservednon-reservednon-reserved
SIZEnon-reservednon-reservednon-reserved
SKIPnon-reservednon-reservednon-reserved
SKEWEDnon-reservednon-reservednon-reserved
SMALLINTnon-reservednon-reservedreserved
SOMEreservednon-reservedreserved
SORTnon-reservednon-reservednon-reserved
SORTEDnon-reservednon-reservednon-reserved
SPACEnon-reservednon-reservednon-reserved
SPECIFICnon-reservednon-reservedreserved
SPECIFICTYPEnon-reservednon-reservedreserved
SQLnon-reservednon-reservedreserved
SQLCODEnon-reservednon-reservednon-reserved
SQLERRORnon-reservednon-reservednon-reserved
SQLEXCEPTIONnon-reservednon-reservedreserved
SQLSTATEnon-reservednon-reservedreserved
SQLWARNINGnon-reservednon-reservedreserved
SQRTnon-reservednon-reservedreserved
STARTnon-reservednon-reservedreserved
STATEnon-reservednon-reservednon-reserved
STATICnon-reservednon-reservedreserved
STATISTICSnon-reservednon-reservednon-reserved
STDDEV_POPnon-reservednon-reservedreserved
STDDEV_SAMPnon-reservednon-reservedreserved
STOREDnon-reservednon-reservednon-reserved
STRATIFYnon-reservednon-reservednon-reserved
STRUCTnon-reservednon-reservednon-reserved
SUBMULTISETnon-reservednon-reservedreserved
SUBSETnon-reservednon-reservednon-reserved
SUBSTRINGnon-reservednon-reservedreserved
SUBSTRING_REGEXnon-reservednon-reservedreserved
SUCCEEDSnon-reservednon-reservedreserved
SUMnon-reservednon-reservedreserved
SYMMETRICnon-reservednon-reservedreserved
SYSTEMnon-reservednon-reservedreserved
SYSTEM_TIMEnon-reservednon-reservedreserved
SYSTEM_USERnon-reservednon-reservedreserved
TABLEreservednon-reservedreserved
TABLESnon-reservednon-reservednon-reserved
TABLESAMPLEnon-reservednon-reservedreserved
TANnon-reservednon-reservednon-reserved
TANHnon-reservednon-reservednon-reserved
TBLPROPERTIESnon-reservednon-reservednon-reserved
TEMPORARYnon-reservednon-reservednon-reserved
TERMINATEDnon-reservednon-reservednon-reserved
THENreservednon-reservedreserved
TIMEnon-reservednon-reservedreserved
TIMESTAMPnon-reservednon-reservedreserved
TIMEZONE_HOURnon-reservednon-reservedreserved
TIMEZONE_MINUTEnon-reservednon-reservedreserved
TOreservednon-reservedreserved
TOUCHnon-reservednon-reservednon-reserved
TRAILINGreservednon-reservedreserved
TRANSACTIONnon-reservednon-reservednon-reserved
TRANSACTIONSnon-reservednon-reservednon-reserved
TRANSFORMnon-reservednon-reservednon-reserved
TRANSLATEnon-reservednon-reservedreserved
TRANSLATE_REGEXnon-reservednon-reservedreserved
TRANSLATIONnon-reservednon-reservedreserved
TREATnon-reservednon-reservedreserved
TRIGGERnon-reservednon-reservedreserved
TRIMnon-reservednon-reservedreserved
TRIM_ARRAYnon-reservednon-reservedreserved
TRUEnon-reservednon-reservedreserved
TRUNCATEnon-reservednon-reservedreserved
UESCAPEnon-reservednon-reservedreserved
UNARCHIVEnon-reservednon-reservednon-reserved
UNBOUNDEDnon-reservednon-reservednon-reserved
UNCACHEnon-reservednon-reservednon-reserved
UNDERnon-reservednon-reservednon-reserved
UNDOnon-reservednon-reservedreserved
UNIONreservedreservedreserved
UNIQUEreservednon-reservedreserved
UNKNOWNnon-reservednon-reservedreserved
UNLOCKnon-reservednon-reservednon-reserved
UNNESTnon-reservednon-reservedreserved
UNSETnon-reservednon-reservednon-reserved
UNTILnon-reservednon-reservedreserved
UPDATEnon-reservednon-reservedreserved
UPPERnon-reservednon-reservedreserved
USAGEnon-reservednon-reservednon-reserved
USEnon-reservednon-reservednon-reserved
USERreservednon-reservedreserved
USINGreservedreservedreserved
VALUEnon-reservednon-reservedreserved
VALUESnon-reservednon-reservedreserved
VALUE_OFnon-reservednon-reservedreserved
VAR_POPnon-reservednon-reservedreserved
VAR_SAMPnon-reservednon-reservedreserved
VARBINARYnon-reservednon-reservedreserved
VARCHARnon-reservednon-reservedreserved
VARYINGnon-reservednon-reservedreserved
VERSIONINGnon-reservednon-reservedreserved
VIEWnon-reservednon-reservednon-reserved
WHENreservednon-reservedreserved
WHENEVERnon-reservednon-reservedreserved
WHEREreservednon-reservedreserved
WHILEnon-reservednon-reservedreserved
WIDTH_BUCKETnon-reservednon-reservedreserved
WINDOWnon-reservednon-reservedreserved
WITHreservednon-reservedreserved
WITHINnon-reservednon-reservedreserved
WITHOUTnon-reservednon-reservedreserved
WORKnon-reservednon-reservednon-reserved
WRITEnon-reservednon-reservednon-reserved
YEARnon-reservednon-reservedreserved
ZONEnon-reservednon-reservednon-reserved
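
To make the three status columns above concrete, here is a minimal, illustrative sketch of how they are expected to interact with the `spark.sql.parser.ansi.enabled` flag introduced later in this series. It is not part of the patch: the table `t`, column `c`, and local-mode session are assumptions made only for the example, and the exact failure mode is expected rather than guaranteed.

```scala
// Illustrative sketch only: how the ANSI/default keyword columns are expected to
// surface through spark.sql.parser.ansi.enabled. Names here (`t`, `c`) are made up.
import org.apache.spark.sql.SparkSession

object KeywordReservationDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("keyword-reservation-demo")
      .getOrCreate()

    spark.sql("CREATE TABLE t(c INT) USING parquet")

    // Default mode: SELECT is non-reserved, so it should be usable as an unquoted alias.
    spark.conf.set("spark.sql.parser.ansi.enabled", "false")
    spark.sql("SELECT c AS select FROM t").show()

    // ANSI mode: SELECT is reserved, so the same unquoted alias is expected to be
    // rejected by the parser; backquoting keeps the name usable either way.
    spark.conf.set("spark.sql.parser.ansi.enabled", "true")
    // spark.sql("SELECT c AS select FROM t")   // expected to fail with a ParseException
    spark.sql("SELECT c AS `select` FROM t").show()

    spark.stop()
  }
}
```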
diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 76a13c5e2478f..f3040f1c89843 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -44,6 +44,11 @@ grammar SqlBase; return true; } } + + /** + * When true, ANSI SQL parsing mode is enabled. + */ + public boolean ansi = false; } singleStatement @@ -728,14 +733,15 @@ qualifiedName identifier : strictIdentifier - | ANTI | FULL | INNER | LEFT | SEMI | RIGHT | NATURAL | JOIN | CROSS | ON - | UNION | INTERSECT | EXCEPT | SETMINUS + | {ansi}? ansiReserved + | {!ansi}? defaultReserved ; strictIdentifier - : IDENTIFIER #unquotedIdentifier - | quotedIdentifier #quotedIdentifierAlternative - | nonReserved #unquotedIdentifier + : IDENTIFIER #unquotedIdentifier + | quotedIdentifier #quotedIdentifierAlternative + | {ansi}? ansiNonReserved #unquotedIdentifier + | {!ansi}? nonReserved #unquotedIdentifier ; quotedIdentifier @@ -752,40 +758,67 @@ number | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral ; +// NOTE: You must follow a rule below when you add a new ANTLR taken in this file: +// - All the ANTLR tokens = UNION(`ansiReserved`, `ansiNonReserved`) = UNION(`defaultReserved`, `nonReserved`) +// +// Let's say you add a new token `NEWTOKEN` and this is not reserved regardless of a `spark.sql.parser.ansi.enabled` +// value. In this case, you must add a token `NEWTOKEN` in both `ansiNonReserved` and `nonReserved`. + +// The list of the reserved keywords when `spark.sql.parser.ansi.enabled` is true. Currently, we only reserve +// the ANSI keywords that almost all the ANSI SQL standards (SQL-92, SQL-99, SQL-2003, SQL-2008, SQL-2011, +// and SQL-2016) and PostgreSQL reserve. +ansiReserved + : ALL | AND | ANTI | ANY | AS | AUTHORIZATION | BOTH | CASE | CAST | CHECK | COLLATE | COLUMN | CONSTRAINT | CREATE + | CROSS | CURRENT_DATE | CURRENT_TIME | CURRENT_TIMESTAMP | CURRENT_USER | DISTINCT | ELSE | END | EXCEPT | FALSE + | FETCH | FOR | FOREIGN | FROM | FULL | GRANT | GROUP | HAVING | IN | INNER | INTERSECT | INTO | JOIN | IS + | LEADING | LEFT | NATURAL | NOT | NULL | ON | ONLY | OR | ORDER | OUTER | OVERLAPS | PRIMARY | REFERENCES | RIGHT + | SELECT | SEMI | SESSION_USER | SETMINUS | SOME | TABLE | THEN | TO | TRAILING | UNION | UNIQUE | USER | USING + | WHEN | WHERE | WITH + ; + + +// The list of the non-reserved keywords when `spark.sql.parser.ansi.enabled` is true. 
+ansiNonReserved + : ADD | AFTER | ALTER | ANALYZE | ARCHIVE | ARRAY | ASC | AT | BETWEEN | BUCKET | BUCKETS | BY | CACHE | CASCADE + | CHANGE | CLEAR | CLUSTER | CLUSTERED | CODEGEN | COLLECTION | COLUMNS | COMMENT | COMMIT | COMPACT | COMPACTIONS + | COMPUTE | CONCATENATE | COST | CUBE | CURRENT | DATA | DATABASE | DATABASES | DBPROPERTIES | DEFINED | DELETE + | DELIMITED | DESC | DESCRIBE | DFS | DIRECTORIES | DIRECTORY | DISTRIBUTE | DIV | DROP | ESCAPED | EXCHANGE + | EXISTS | EXPLAIN | EXPORT | EXTENDED | EXTERNAL | EXTRACT | FIELDS | FILEFORMAT | FIRST | FOLLOWING | FORMAT + | FORMATTED | FUNCTION | FUNCTIONS | GLOBAL | GROUPING | IF | IGNORE | IMPORT | INDEX | INDEXES | INPATH + | INPUTFORMAT | INSERT | INTERVAL | ITEMS | KEYS | LAST | LATERAL | LAZY | LIKE | LIMIT | LINES | LIST | LOAD + | LOCAL | LOCATION | LOCK | LOCKS | LOGICAL | MACRO | MAP | MSCK | NO | NULLS | OF | OPTION | OPTIONS | OUT + | OUTPUTFORMAT | OVER | OVERWRITE | PARTITION | PARTITIONED | PARTITIONS | PERCENT | PERCENTLIT | PIVOT | PRECEDING + | PRINCIPALS | PURGE | RANGE | RECORDREADER | RECORDWRITER | RECOVER | REDUCE | REFRESH | RENAME | REPAIR | REPLACE + | RESET | RESTRICT | REVOKE | RLIKE | ROLE | ROLES | ROLLBACK | ROLLUP | ROW | ROWS | SCHEMA | SEPARATED | SERDE + | SERDEPROPERTIES | SET | SETS | SHOW | SKEWED | SORT | SORTED | START | STATISTICS | STORED | STRATIFY | STRUCT + | TABLES | TABLESAMPLE | TBLPROPERTIES | TEMPORARY | TERMINATED | TOUCH | TRANSACTION | TRANSACTIONS | TRANSFORM + | TRUE | TRUNCATE | UNARCHIVE | UNBOUNDED | UNCACHE | UNLOCK | UNSET | USE | VALUES | VIEW | WINDOW + ; + +defaultReserved + : ANTI | CROSS | EXCEPT | FULL | INNER | INTERSECT | JOIN | LEFT | NATURAL | ON | RIGHT | SEMI | SETMINUS | UNION + | USING + ; + nonReserved - : SHOW | TABLES | COLUMNS | COLUMN | PARTITIONS | FUNCTIONS | DATABASES - | ADD - | OVER | PARTITION | RANGE | ROWS | PRECEDING | FOLLOWING | CURRENT | ROW | LAST | FIRST | AFTER - | MAP | ARRAY | STRUCT - | PIVOT | LATERAL | WINDOW | REDUCE | TRANSFORM | SERDE | SERDEPROPERTIES | RECORDREADER - | DELIMITED | FIELDS | TERMINATED | COLLECTION | ITEMS | KEYS | ESCAPED | LINES | SEPARATED - | EXTENDED | REFRESH | CLEAR | CACHE | UNCACHE | LAZY | GLOBAL | TEMPORARY | OPTIONS - | GROUPING | CUBE | ROLLUP - | EXPLAIN | FORMAT | LOGICAL | FORMATTED | CODEGEN | COST - | TABLESAMPLE | USE | TO | BUCKET | PERCENTLIT | OUT | OF - | SET | RESET - | VIEW | REPLACE - | IF - | POSITION - | EXTRACT - | NO | DATA - | START | TRANSACTION | COMMIT | ROLLBACK | IGNORE - | SORT | CLUSTER | DISTRIBUTE | UNSET | TBLPROPERTIES | SKEWED | STORED | DIRECTORIES | LOCATION - | EXCHANGE | ARCHIVE | UNARCHIVE | FILEFORMAT | TOUCH | COMPACT | CONCATENATE | CHANGE - | CASCADE | RESTRICT | BUCKETS | CLUSTERED | SORTED | PURGE | INPUTFORMAT | OUTPUTFORMAT - | DBPROPERTIES | DFS | TRUNCATE | COMPUTE | LIST - | STATISTICS | ANALYZE | PARTITIONED | EXTERNAL | DEFINED | RECORDWRITER - | REVOKE | GRANT | LOCK | UNLOCK | MSCK | REPAIR | RECOVER | EXPORT | IMPORT | LOAD | VALUES | COMMENT | ROLE - | ROLES | COMPACTIONS | PRINCIPALS | TRANSACTIONS | INDEX | INDEXES | LOCKS | OPTION | LOCAL | INPATH - | ASC | DESC | LIMIT | RENAME | SETS - | AT | NULLS | OVERWRITE | ALL | ANY | ALTER | AS | BETWEEN | BY | CREATE | DELETE - | DESCRIBE | DROP | EXISTS | FALSE | FOR | GROUP | IN | INSERT | INTO | IS |LIKE - | NULL | ORDER | OUTER | TABLE | TRUE | WITH | RLIKE - | AND | CASE | CAST | DISTINCT | DIV | ELSE | END | FUNCTION | INTERVAL | MACRO | OR | STRATIFY | THEN - | UNBOUNDED | WHEN - | DATABASE | 
SELECT | FROM | WHERE | HAVING | TO | TABLE | WITH | NOT - | DIRECTORY - | BOTH | LEADING | TRAILING + : ADD | AFTER | ALL | ALTER | ANALYZE | AND | ANY | ARCHIVE | ARRAY | AS | ASC | AT | AUTHORIZATION | BETWEEN + | BOTH | BUCKET | BUCKETS | BY | CACHE | CASCADE | CASE | CAST | CHANGE | CHECK | CLEAR | CLUSTER | CLUSTERED + | CODEGEN | COLLATE | COLLECTION | COLUMN | COLUMNS | COMMENT | COMMIT | COMPACT | COMPACTIONS | COMPUTE + | CONCATENATE | CONSTRAINT | COST | CREATE | CUBE | CURRENT | CURRENT_DATE | CURRENT_TIME | CURRENT_TIMESTAMP + | CURRENT_USER | DATA | DATABASE | DATABASES | DBPROPERTIES | DEFINED | DELETE | DELIMITED | DESC | DESCRIBE | DFS + | DIRECTORIES | DIRECTORY | DISTINCT | DISTRIBUTE | DIV | DROP | ELSE | END | ESCAPED | EXCHANGE | EXISTS | EXPLAIN + | EXPORT | EXTENDED | EXTERNAL | EXTRACT | FALSE | FETCH | FIELDS | FILEFORMAT | FIRST | FOLLOWING | FOR | FOREIGN + | FORMAT | FORMATTED | FROM | FUNCTION | FUNCTIONS | GLOBAL | GRANT | GROUP | GROUPING | HAVING | IF | IGNORE + | IMPORT | IN | INDEX | INDEXES | INPATH | INPUTFORMAT | INSERT | INTERVAL | INTO | IS | ITEMS | KEYS | LAST + | LATERAL | LAZY | LEADING | LIKE | LIMIT | LINES | LIST | LOAD | LOCAL | LOCATION | LOCK | LOCKS | LOGICAL | MACRO + | MAP | MSCK | NO | NOT | NULL | NULLS | OF | ONLY | OPTION | OPTIONS | OR | ORDER | OUT | OUTER | OUTPUTFORMAT + | OVER | OVERLAPS | OVERWRITE | PARTITION | PARTITIONED | PARTITIONS | PERCENTLIT | PIVOT | POSITION | PRECEDING + | PRIMARY | PRINCIPALS | PURGE | RANGE | RECORDREADER | RECORDWRITER | RECOVER | REDUCE | REFERENCES | REFRESH + | RENAME | REPAIR | REPLACE | RESET | RESTRICT | REVOKE | RLIKE | ROLE | ROLES | ROLLBACK | ROLLUP | ROW | ROWS + | SELECT | SEPARATED | SERDE | SERDEPROPERTIES | SESSION_USER | SET | SETS | SHOW | SKEWED | SOME | SORT | SORTED + | START | STATISTICS | STORED | STRATIFY | STRUCT | TABLE | TABLES | TABLESAMPLE | TBLPROPERTIES | TEMPORARY + | TERMINATED | THEN | TO | TOUCH | TRAILING | TRANSACTION | TRANSACTIONS | TRANSFORM | TRUE | TRUNCATE | UNARCHIVE + | UNBOUNDED | UNCACHE | UNLOCK | UNIQUE | UNSET | USE | USER | VALUES | VIEW | WHEN | WHERE | WINDOW | WITH ; SELECT: 'SELECT'; @@ -1022,6 +1055,24 @@ OPTION: 'OPTION'; ANTI: 'ANTI'; LOCAL: 'LOCAL'; INPATH: 'INPATH'; +AUTHORIZATION: 'AUTHORIZATION'; +CHECK: 'CHECK'; +COLLATE: 'COLLATE'; +CONSTRAINT: 'CONSTRAINT'; +CURRENT_DATE: 'CURRENT_DATE'; +CURRENT_TIME: 'CURRENT_TIME'; +CURRENT_TIMESTAMP: 'CURRENT_TIMESTAMP'; +CURRENT_USER: 'CURRENT_USER'; +FETCH: 'FETCH'; +FOREIGN: 'FOREIGN'; +ONLY: 'ONLY'; +OVERLAPS: 'OVERLAPS'; +PRIMARY: 'PRIMARY'; +REFERENCES: 'REFERENCES'; +SESSION_USER: 'SESSION_USER'; +SOME: 'SOME'; +UNIQUE: 'UNIQUE'; +USER: 'USER'; STRING : '\'' ( ~('\''|'\\') | ('\\' .) 
)* '\'' diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala index ffc64f78e3003..31917ab9a5579 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParseDriver.scala @@ -92,6 +92,7 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { lexer.removeErrorListeners() lexer.addErrorListener(ParseErrorListener) lexer.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced + lexer.ansi = SQLConf.get.ansiParserEnabled val tokenStream = new CommonTokenStream(lexer) val parser = new SqlBaseParser(tokenStream) @@ -99,6 +100,7 @@ abstract class AbstractSqlParser extends ParserInterface with Logging { parser.removeErrorListeners() parser.addErrorListener(ParseErrorListener) parser.legacy_setops_precedence_enbled = SQLConf.get.setOpsPrecedenceEnforced + parser.ansi = SQLConf.get.ansiParserEnabled try { try { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5900a72f3387e..7b0bd76c90cc9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -314,6 +314,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val ANSI_SQL_PARSER = + buildConf("spark.sql.parser.ansi.enabled") + .doc("When true, tries to conform to ANSI SQL syntax.") + .booleanConf + .createWithDefault(false) + val ESCAPED_STRING_LITERALS = buildConf("spark.sql.parser.escapedStringLiterals") .internal() .doc("When true, string literals (including regex patterns) remain escaped in our SQL " + @@ -1869,6 +1875,8 @@ class SQLConf extends Serializable with Logging { def constraintPropagationEnabled: Boolean = getConf(CONSTRAINT_PROPAGATION_ENABLED) + def ansiParserEnabled: Boolean = getConf(ANSI_SQL_PARSER) + def escapedStringLiterals: Boolean = getConf(ESCAPED_STRING_LITERALS) def fileCompressionFactor: Double = getConf(FILE_COMRESSION_FACTOR) From 4a7c0077214e0aefe0d783424dda49d563ab0303 Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Fri, 1 Mar 2019 12:34:15 -0800 Subject: [PATCH 24/70] [SPARK-26215][SQL][FOLLOW-UP][MINOR] Fix the warning from ANTR4 ## What changes were proposed in this pull request? I see the following new warning from ANTR4 after SPARK-26215 after it added `SCHEMA` keyword in the reserved/unreserved list. This is a minor PR to cleanup the warning. ``` WARNING] warning(125): org/apache/spark/sql/catalyst/parser/SqlBase.g4:784:90: implicit definition of token SCHEMA in parser [WARNING] .../apache/spark/org/apache/spark/sql/catalyst/parser/SqlBase.g4 [784:90]: implicit definition of token SCHEMA in parser ``` ## How was this patch tested? Manually built catalyst after the fix to verify Closes #23897 from dilipbiswal/minor_parser_token. 
Authored-by: Dilip Biswal Signed-off-by: Dongjoon Hyun --- .../spark/sql/catalyst/parser/SqlBase.g4 | 18 ++++++++++++------ .../sql/execution/SparkSqlParserSuite.scala | 8 ++++++++ 2 files changed, 20 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index f3040f1c89843..20f2ffb5fa7da 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -82,11 +82,11 @@ singleTableSchema statement : query #statementDefault | USE db=identifier #use - | CREATE DATABASE (IF NOT EXISTS)? identifier + | CREATE database (IF NOT EXISTS)? identifier (COMMENT comment=STRING)? locationSpec? (WITH DBPROPERTIES tablePropertyList)? #createDatabase - | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties - | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase + | ALTER database identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties + | DROP database (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase | createTableHeader ('(' colTypeList ')')? tableProvider ((OPTIONS options=tablePropertyList) | (PARTITIONED BY partitionColumnNames=identifierList) | @@ -167,7 +167,7 @@ statement (LIKE? (qualifiedName | pattern=STRING))? #showFunctions | SHOW CREATE TABLE tableIdentifier #showCreateTable | (DESC | DESCRIBE) FUNCTION EXTENDED? describeFuncName #describeFunction - | (DESC | DESCRIBE) DATABASE EXTENDED? identifier #describeDatabase + | (DESC | DESCRIBE) database EXTENDED? identifier #describeDatabase | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)? tableIdentifier partitionSpec? describeColName? #describeTable | REFRESH TABLE tableIdentifier #refreshTable @@ -278,6 +278,11 @@ partitionVal : identifier (EQ constant)? ; +database + : DATABASE + | SCHEMA + ; + describeFuncName : qualifiedName | STRING @@ -758,7 +763,7 @@ number | MINUS? BIGDECIMAL_LITERAL #bigDecimalLiteral ; -// NOTE: You must follow a rule below when you add a new ANTLR taken in this file: +// NOTE: You must follow a rule below when you add a new ANTLR token in this file: // - All the ANTLR tokens = UNION(`ansiReserved`, `ansiNonReserved`) = UNION(`defaultReserved`, `nonReserved`) // // Let's say you add a new token `NEWTOKEN` and this is not reserved regardless of a `spark.sql.parser.ansi.enabled` @@ -1022,7 +1027,8 @@ SORTED: 'SORTED'; PURGE: 'PURGE'; INPUTFORMAT: 'INPUTFORMAT'; OUTPUTFORMAT: 'OUTPUTFORMAT'; -DATABASE: 'DATABASE' | 'SCHEMA'; +SCHEMA: 'SCHEMA'; +DATABASE: 'DATABASE'; DATABASES: 'DATABASES' | 'SCHEMAS'; DFS: 'DFS'; TRUNCATE: 'TRUNCATE'; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 31b9bcdafbab8..038871cb1fc12 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -387,4 +387,12 @@ class SparkSqlParserSuite extends AnalysisTest { "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", "Operation not allowed: ALTER VIEW ... AS FROM ... 
[INSERT INTO ...]+") } + + test("database and schema tokens are interchangeable") { + assertEqual("CREATE DATABASE foo", parser.parsePlan("CREATE SCHEMA foo")) + assertEqual("DROP DATABASE foo", parser.parsePlan("DROP SCHEMA foo")) + assertEqual("ALTER DATABASE foo SET DBPROPERTIES ('x' = 'y')", + parser.parsePlan("ALTER SCHEMA foo SET DBPROPERTIES ('x' = 'y')")) + assertEqual("DESC DATABASE foo", parser.parsePlan("DESC SCHEMA foo")) + } } From 6cb92345a8755389a993efdad518c5a90c0c47fe Mon Sep 17 00:00:00 2001 From: Dilip Biswal Date: Sat, 2 Mar 2019 11:21:23 +0800 Subject: [PATCH 25/70] [SPARK-26982][SQL] Enhance describe framework to describe the output of a query. Currently we can use `df.printSchema` to discover the schema information for a query. We should have a way to describe the output schema of a query using SQL interface. Example: DESCRIBE SELECT * FROM desc_table DESCRIBE QUERY SELECT * FROM desc_table ```SQL spark-sql> create table desc_table (c1 int comment 'c1-comment', c2 decimal comment 'c2-comment', c3 string); spark-sql> desc select * from desc_table; c1 int c1-comment c2 decimal(10,0) c2-comment c3 string NULL ``` Added a new test under SQLQueryTestSuite and SparkSqlParserSuite Closes #23883 from dilipbiswal/dkb_describe_query. Authored-by: Dilip Biswal Signed-off-by: Wenchen Fan --- .../sql-reserved-and-non-reserved-keywords.md | 1 + .../spark/sql/catalyst/parser/SqlBase.g4 | 10 +- .../sql/catalyst/parser/AstBuilder.scala | 4 + .../parser/TableIdentifierParserSuite.scala | 4 +- .../spark/sql/execution/HiveResult.scala | 6 +- .../spark/sql/execution/SparkSqlParser.scala | 7 + .../spark/sql/execution/command/tables.scala | 81 ++++++--- .../sql-tests/inputs/describe-query.sql | 27 +++ .../sql-tests/results/describe-query.sql.out | 171 ++++++++++++++++++ .../apache/spark/sql/SQLQueryTestSuite.scala | 4 +- .../sql/execution/SparkSqlParserSuite.scala | 17 +- .../hive/execution/HiveComparisonTest.scala | 4 +- 12 files changed, 292 insertions(+), 44 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/describe-query.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/describe-query.sql.out diff --git a/docs/sql-reserved-and-non-reserved-keywords.md b/docs/sql-reserved-and-non-reserved-keywords.md index 321fb3f00acbd..53eb9988f6c88 100644 --- a/docs/sql-reserved-and-non-reserved-keywords.md +++ b/docs/sql-reserved-and-non-reserved-keywords.md @@ -156,6 +156,7 @@ The list of reserved and non-reserved keywords can change according to the confi DEREFnon-reservednon-reservedreserved DESCnon-reservednon-reservednon-reserved DESCRIBEnon-reservednon-reservedreserved + QUERYnon-reservednon-reservednon-reserved DESCRIPTORnon-reservednon-reservednon-reserved DETERMINISTICnon-reservednon-reservedreserved DFSnon-reservednon-reservednon-reserved diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 20f2ffb5fa7da..0963ecf00c3d1 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -170,6 +170,7 @@ statement | (DESC | DESCRIBE) database EXTENDED? identifier #describeDatabase | (DESC | DESCRIBE) TABLE? option=(EXTENDED | FORMATTED)? tableIdentifier partitionSpec? describeColName? #describeTable + | (DESC | DESCRIBE) QUERY? 
queryToDesc #describeQuery | REFRESH TABLE tableIdentifier #refreshTable | REFRESH (STRING | .*?) #refreshResource | CACHE LAZY? TABLE tableIdentifier @@ -259,6 +260,10 @@ query : ctes? queryNoWith ; +queryToDesc + : queryTerm queryOrganization + ; + insertInto : INSERT OVERWRITE TABLE tableIdentifier (partitionSpec (IF NOT EXISTS)?)? #insertOverwriteTable | INSERT INTO TABLE? tableIdentifier partitionSpec? #insertIntoTable @@ -793,7 +798,7 @@ ansiNonReserved | INPUTFORMAT | INSERT | INTERVAL | ITEMS | KEYS | LAST | LATERAL | LAZY | LIKE | LIMIT | LINES | LIST | LOAD | LOCAL | LOCATION | LOCK | LOCKS | LOGICAL | MACRO | MAP | MSCK | NO | NULLS | OF | OPTION | OPTIONS | OUT | OUTPUTFORMAT | OVER | OVERWRITE | PARTITION | PARTITIONED | PARTITIONS | PERCENT | PERCENTLIT | PIVOT | PRECEDING - | PRINCIPALS | PURGE | RANGE | RECORDREADER | RECORDWRITER | RECOVER | REDUCE | REFRESH | RENAME | REPAIR | REPLACE + | PRINCIPALS | PURGE | QUERY | RANGE | RECORDREADER | RECORDWRITER | RECOVER | REDUCE | REFRESH | RENAME | REPAIR | REPLACE | RESET | RESTRICT | REVOKE | RLIKE | ROLE | ROLES | ROLLBACK | ROLLUP | ROW | ROWS | SCHEMA | SEPARATED | SERDE | SERDEPROPERTIES | SET | SETS | SHOW | SKEWED | SORT | SORTED | START | STATISTICS | STORED | STRATIFY | STRUCT | TABLES | TABLESAMPLE | TBLPROPERTIES | TEMPORARY | TERMINATED | TOUCH | TRANSACTION | TRANSACTIONS | TRANSFORM @@ -818,7 +823,7 @@ nonReserved | LATERAL | LAZY | LEADING | LIKE | LIMIT | LINES | LIST | LOAD | LOCAL | LOCATION | LOCK | LOCKS | LOGICAL | MACRO | MAP | MSCK | NO | NOT | NULL | NULLS | OF | ONLY | OPTION | OPTIONS | OR | ORDER | OUT | OUTER | OUTPUTFORMAT | OVER | OVERLAPS | OVERWRITE | PARTITION | PARTITIONED | PARTITIONS | PERCENTLIT | PIVOT | POSITION | PRECEDING - | PRIMARY | PRINCIPALS | PURGE | RANGE | RECORDREADER | RECORDWRITER | RECOVER | REDUCE | REFERENCES | REFRESH + | PRIMARY | PRINCIPALS | PURGE | QUERY | RANGE | RECORDREADER | RECORDWRITER | RECOVER | REDUCE | REFERENCES | REFRESH | RENAME | REPAIR | REPLACE | RESET | RESTRICT | REVOKE | RLIKE | ROLE | ROLES | ROLLBACK | ROLLUP | ROW | ROWS | SELECT | SEPARATED | SERDE | SERDEPROPERTIES | SESSION_USER | SET | SETS | SHOW | SKEWED | SOME | SORT | SORTED | START | STATISTICS | STORED | STRATIFY | STRUCT | TABLE | TABLES | TABLESAMPLE | TBLPROPERTIES | TEMPORARY @@ -896,6 +901,7 @@ WITH: 'WITH'; VALUES: 'VALUES'; CREATE: 'CREATE'; TABLE: 'TABLE'; +QUERY: 'QUERY'; DIRECTORY: 'DIRECTORY'; VIEW: 'VIEW'; REPLACE: 'REPLACE'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index aa6d8cf7e5ad0..d2f26e700b6fb 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -122,6 +122,10 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } } + override def visitQueryToDesc(ctx: QueryToDescContext): LogicalPlan = withOrigin(ctx) { + plan(ctx.queryTerm).optionalMap(ctx.queryOrganization)(withQueryResultClauses) + } + /** * Create a named logical plan. 
* diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index ff0de0fb7c1f0..489b7f328f8fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -47,8 +47,8 @@ class TableIdentifierParserSuite extends SparkFunSuite { "cursor", "date", "decimal", "delete", "describe", "double", "drop", "exists", "external", "false", "fetch", "float", "for", "grant", "group", "grouping", "import", "in", "insert", "int", "into", "is", "pivot", "lateral", "like", "local", "none", "null", - "of", "order", "out", "outer", "partition", "percent", "procedure", "range", "reads", "revoke", - "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", + "of", "order", "out", "outer", "partition", "percent", "procedure", "query", "range", "reads", + "revoke", "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", "bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float", "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing", "extract") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index c90b254a6d121..a369b49777a5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -21,8 +21,8 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExecutedCommandExec, ShowTablesCommand} +import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} +import org.apache.spark.sql.execution.command.{DescribeCommandBase, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -35,7 +35,7 @@ object HiveResult { * `SparkSQLDriver` for CLI applications. */ def hiveResultString(executedPlan: SparkPlan): Seq[String] = executedPlan match { - case ExecutedCommandExec(desc: DescribeTableCommand) => + case ExecutedCommandExec(_: DescribeCommandBase) => // If it is a describe command for a Hive table, we want to have the output format // be similar with Hive. executedPlan.executeCollectPublic().map { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 8deb55b00a9d3..c17cf5de9066b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -369,6 +369,13 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } } + /** + * Create a [[DescribeQueryCommand]] logical command. 
+ */ + override def visitDescribeQuery(ctx: DescribeQueryContext): LogicalPlan = withOrigin(ctx) { + DescribeQueryCommand(visitQueryToDesc(ctx.queryToDesc())) + } + /** * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal). */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index d24e66e583857..8b70e336c14bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -29,12 +29,12 @@ import org.apache.hadoop.fs.{FileContext, FsConstants, Path} import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, UnresolvedAttribute} +import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, UnresolvedAttribute, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.catalog.CatalogTableType._ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.Histogram +import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIdentifier} import org.apache.spark.sql.execution.datasources.{DataSource, PartitioningUtils} import org.apache.spark.sql.execution.datasources.csv.CSVFileFormat @@ -494,6 +494,34 @@ case class TruncateTableCommand( } } +abstract class DescribeCommandBase extends RunnableCommand { + override val output: Seq[Attribute] = Seq( + // Column names are based on Hive. + AttributeReference("col_name", StringType, nullable = false, + new MetadataBuilder().putString("comment", "name of the column").build())(), + AttributeReference("data_type", StringType, nullable = false, + new MetadataBuilder().putString("comment", "data type of the column").build())(), + AttributeReference("comment", StringType, nullable = true, + new MetadataBuilder().putString("comment", "comment of the column").build())() + ) + + protected def describeSchema( + schema: StructType, + buffer: ArrayBuffer[Row], + header: Boolean): Unit = { + if (header) { + append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) + } + schema.foreach { column => + append(buffer, column.name, column.dataType.simpleString, column.getComment().orNull) + } + } + + protected def append( + buffer: ArrayBuffer[Row], column: String, dataType: String, comment: String): Unit = { + buffer += Row(column, dataType, comment) + } +} /** * Command that looks like * {{{ @@ -504,17 +532,7 @@ case class DescribeTableCommand( table: TableIdentifier, partitionSpec: TablePartitionSpec, isExtended: Boolean) - extends RunnableCommand { - - override val output: Seq[Attribute] = Seq( - // Column names are based on Hive. 
- AttributeReference("col_name", StringType, nullable = false, - new MetadataBuilder().putString("comment", "name of the column").build())(), - AttributeReference("data_type", StringType, nullable = false, - new MetadataBuilder().putString("comment", "data type of the column").build())(), - AttributeReference("comment", StringType, nullable = true, - new MetadataBuilder().putString("comment", "comment of the column").build())() - ) + extends DescribeCommandBase { override def run(sparkSession: SparkSession): Seq[Row] = { val result = new ArrayBuffer[Row] @@ -603,22 +621,31 @@ case class DescribeTableCommand( } table.storage.toLinkedHashMap.foreach(s => append(buffer, s._1, s._2, "")) } +} - private def describeSchema( - schema: StructType, - buffer: ArrayBuffer[Row], - header: Boolean): Unit = { - if (header) { - append(buffer, s"# ${output.head.name}", output(1).name, output(2).name) - } - schema.foreach { column => - append(buffer, column.name, column.dataType.simpleString, column.getComment().orNull) - } - } +/** + * Command that looks like + * {{{ + * DESCRIBE [QUERY] statement + * }}} + * + * Parameter 'statement' can be one of the following types : + * 1. SELECT statements + * 2. SELECT statements inside set operators (UNION, INTERSECT etc) + * 3. VALUES statement. + * 4. TABLE statement. Example : TABLE table_name + * 5. statements of the form 'FROM table SELECT *' + * + * TODO : support CTEs. + */ +case class DescribeQueryCommand(query: LogicalPlan) + extends DescribeCommandBase { - private def append( - buffer: ArrayBuffer[Row], column: String, dataType: String, comment: String): Unit = { - buffer += Row(column, dataType, comment) + override def run(sparkSession: SparkSession): Seq[Row] = { + val result = new ArrayBuffer[Row] + val queryExecution = sparkSession.sessionState.executePlan(query) + describeSchema(queryExecution.analyzed.schema, result, header = false) + result } } diff --git a/sql/core/src/test/resources/sql-tests/inputs/describe-query.sql b/sql/core/src/test/resources/sql-tests/inputs/describe-query.sql new file mode 100644 index 0000000000000..bc144d01cee64 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/describe-query.sql @@ -0,0 +1,27 @@ +-- Test tables +CREATE table desc_temp1 (key int COMMENT 'column_comment', val string) USING PARQUET; +CREATE table desc_temp2 (key int, val string) USING PARQUET; + +-- Simple Describe query +DESC SELECT key, key + 1 as plusone FROM desc_temp1; +DESC QUERY SELECT * FROM desc_temp2; +DESC SELECT key, COUNT(*) as count FROM desc_temp1 group by key; +DESC SELECT 10.00D as col1; +DESC QUERY SELECT key FROM desc_temp1 UNION ALL select CAST(1 AS DOUBLE); +DESC QUERY VALUES(1.00D, 'hello') as tab1(col1, col2); +DESC QUERY FROM desc_temp1 a SELECT *; + + +-- Error cases. 
+DESC WITH s AS (SELECT 'hello' as col1) SELECT * FROM s; +DESCRIBE QUERY WITH s AS (SELECT * from desc_temp1) SELECT * FROM s; +DESCRIBE INSERT INTO desc_temp1 values (1, 'val1'); +DESCRIBE INSERT INTO desc_temp1 SELECT * FROM desc_temp2; +DESCRIBE + FROM desc_temp1 a + insert into desc_temp1 select * + insert into desc_temp2 select *; + +-- cleanup +DROP TABLE desc_temp1; +DROP TABLE desc_temp2; diff --git a/sql/core/src/test/resources/sql-tests/results/describe-query.sql.out b/sql/core/src/test/resources/sql-tests/results/describe-query.sql.out new file mode 100644 index 0000000000000..36cb314884779 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/describe-query.sql.out @@ -0,0 +1,171 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 16 + + +-- !query 0 +CREATE table desc_temp1 (key int COMMENT 'column_comment', val string) USING PARQUET +-- !query 0 schema +struct<> +-- !query 0 output + + + +-- !query 1 +CREATE table desc_temp2 (key int, val string) USING PARQUET +-- !query 1 schema +struct<> +-- !query 1 output + + + +-- !query 2 +DESC SELECT key, key + 1 as plusone FROM desc_temp1 +-- !query 2 schema +struct +-- !query 2 output +key int column_comment +plusone int + + +-- !query 3 +DESC QUERY SELECT * FROM desc_temp2 +-- !query 3 schema +struct +-- !query 3 output +key int +val string + + +-- !query 4 +DESC SELECT key, COUNT(*) as count FROM desc_temp1 group by key +-- !query 4 schema +struct +-- !query 4 output +key int column_comment +count bigint + + +-- !query 5 +DESC SELECT 10.00D as col1 +-- !query 5 schema +struct +-- !query 5 output +col1 double + + +-- !query 6 +DESC QUERY SELECT key FROM desc_temp1 UNION ALL select CAST(1 AS DOUBLE) +-- !query 6 schema +struct +-- !query 6 output +key double + + +-- !query 7 +DESC QUERY VALUES(1.00D, 'hello') as tab1(col1, col2) +-- !query 7 schema +struct +-- !query 7 output +col1 double +col2 string + + +-- !query 8 +DESC QUERY FROM desc_temp1 a SELECT * +-- !query 8 schema +struct +-- !query 8 output +key int column_comment +val string + + +-- !query 9 +DESC WITH s AS (SELECT 'hello' as col1) SELECT * FROM s +-- !query 9 schema +struct<> +-- !query 9 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'AS' expecting {, '.'}(line 1, pos 12) + +== SQL == +DESC WITH s AS (SELECT 'hello' as col1) SELECT * FROM s +------------^^^ + + +-- !query 10 +DESCRIBE QUERY WITH s AS (SELECT * from desc_temp1) SELECT * FROM s +-- !query 10 schema +struct<> +-- !query 10 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 's' expecting {, '.'}(line 1, pos 20) + +== SQL == +DESCRIBE QUERY WITH s AS (SELECT * from desc_temp1) SELECT * FROM s +--------------------^^^ + + +-- !query 11 +DESCRIBE INSERT INTO desc_temp1 values (1, 'val1') +-- !query 11 schema +struct<> +-- !query 11 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'desc_temp1' expecting {, '.'}(line 1, pos 21) + +== SQL == +DESCRIBE INSERT INTO desc_temp1 values (1, 'val1') +---------------------^^^ + + +-- !query 12 +DESCRIBE INSERT INTO desc_temp1 SELECT * FROM desc_temp2 +-- !query 12 schema +struct<> +-- !query 12 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'desc_temp1' expecting {, '.'}(line 1, pos 21) + +== SQL == +DESCRIBE INSERT INTO desc_temp1 SELECT * FROM desc_temp2 +---------------------^^^ + + +-- !query 13 +DESCRIBE + FROM desc_temp1 a + insert into desc_temp1 select * + insert into desc_temp2 select * +-- !query 13 
schema +struct<> +-- !query 13 output +org.apache.spark.sql.catalyst.parser.ParseException + +mismatched input 'insert' expecting {, '(', ',', 'SELECT', 'WHERE', 'GROUP', 'ORDER', 'HAVING', 'LIMIT', 'JOIN', 'CROSS', 'INNER', 'LEFT', 'RIGHT', 'FULL', 'NATURAL', 'PIVOT', 'LATERAL', 'WINDOW', 'UNION', 'EXCEPT', 'MINUS', 'INTERSECT', 'SORT', 'CLUSTER', 'DISTRIBUTE', 'ANTI'}(line 3, pos 5) + +== SQL == +DESCRIBE + FROM desc_temp1 a + insert into desc_temp1 select * +-----^^^ + insert into desc_temp2 select * + + +-- !query 14 +DROP TABLE desc_temp1 +-- !query 14 schema +struct<> +-- !query 14 output + + + +-- !query 15 +DROP TABLE desc_temp2 +-- !query 15 schema +struct<> +-- !query 15 output + diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 24b312348bd67..62f3f98bf28ae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.rules.RuleExecutor import org.apache.spark.sql.catalyst.util.{fileToString, stringToFile} import org.apache.spark.sql.execution.HiveResult.hiveResultString -import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeTableCommand} +import org.apache.spark.sql.execution.command.{DescribeColumnCommand, DescribeCommandBase} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.StructType @@ -277,7 +277,7 @@ class SQLQueryTestSuite extends QueryTest with SharedSQLContext { // Returns true if the plan is supposed to be sorted. def isSorted(plan: LogicalPlan): Boolean = plan match { case _: Join | _: Aggregate | _: Generate | _: Sample | _: Distinct => false - case _: DescribeTableCommand | _: DescribeColumnCommand => true + case _: DescribeCommandBase | _: DescribeColumnCommand => true case PhysicalOperation(_, _, Sort(_, true, _)) => true case _ => plan.children.iterator.exists(isSorted) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 038871cb1fc12..425a96b871ad2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -240,15 +240,20 @@ class SparkSqlParserSuite extends AnalysisTest { } test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") { + assertEqual("describe t", + DescribeTableCommand(TableIdentifier("t"), Map.empty, isExtended = false)) assertEqual("describe table t", - DescribeTableCommand( - TableIdentifier("t"), Map.empty, isExtended = false)) + DescribeTableCommand(TableIdentifier("t"), Map.empty, isExtended = false)) assertEqual("describe table extended t", - DescribeTableCommand( - TableIdentifier("t"), Map.empty, isExtended = true)) + DescribeTableCommand(TableIdentifier("t"), Map.empty, isExtended = true)) assertEqual("describe table formatted t", - DescribeTableCommand( - TableIdentifier("t"), Map.empty, isExtended = true)) + DescribeTableCommand(TableIdentifier("t"), Map.empty, isExtended = true)) + } + + test("describe query") { + val query = "SELECT * FROM t" + assertEqual("DESCRIBE QUERY " + query, DescribeQueryCommand(parser.parsePlan(query))) + assertEqual("DESCRIBE " + query, 
DescribeQueryCommand(parser.parsePlan(query))) } test("describe table column") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala index 66426824573c6..a4587abbf389d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveComparisonTest.scala @@ -172,7 +172,7 @@ abstract class HiveComparisonTest // and does not return it as a query answer. case _: SetCommand => Seq("0") case _: ExplainCommand => answer - case _: DescribeTableCommand | ShowColumnsCommand(_, _) => + case _: DescribeCommandBase | ShowColumnsCommand(_, _) => // Filter out non-deterministic lines and lines which do not have actual results but // can introduce problems because of the way Hive formats these lines. // Then, remove empty lines. Do not sort the results. @@ -375,7 +375,7 @@ abstract class HiveComparisonTest if ((!hiveQuery.logical.isInstanceOf[ExplainCommand]) && (!hiveQuery.logical.isInstanceOf[ShowFunctionsCommand]) && (!hiveQuery.logical.isInstanceOf[DescribeFunctionCommand]) && - (!hiveQuery.logical.isInstanceOf[DescribeTableCommand]) && + (!hiveQuery.logical.isInstanceOf[DescribeCommandBase]) && preparedHive != catalyst) { val hivePrintOut = s"== HIVE - ${preparedHive.size} row(s) ==" +: preparedHive From f0d9915ab9e30764d8becc4199acdbe3bbe92ed0 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 22 Mar 2019 13:58:54 -0700 Subject: [PATCH 26/70] [SPARK-27108][SQL] Add parsed SQL plans for create, CTAS. This moves parsing `CREATE TABLE ... USING` statements into catalyst. Catalyst produces logical plans with the parsed information and those plans are converted to v1 `DataSource` plans in `DataSourceAnalysis`. This prepares for adding v2 create plans that should receive the information parsed from SQL without being translated to v1 plans first. This also makes it possible to parse in catalyst instead of breaking the parser across the abstract `AstBuilder` in catalyst and `SparkSqlParser` in core. For more information, see the [mailing list thread](https://lists.apache.org/thread.html/54f4e1929ceb9a2b0cac7cb058000feb8de5d6c667b2e0950804c613%3Cdev.spark.apache.org%3E). This uses existing tests to catch regressions. This introduces no behavior changes. Closes #24029 from rdblue/SPARK-27108-add-parsed-create-logical-plans. 
Authored-by: Ryan Blue Signed-off-by: Wenchen Fan --- .../sql/catalyst/parser/AstBuilder.scala | 193 ++++++++++- .../logical/sql/CreateTableStatement.scala | 66 ++++ .../plans/logical/sql/ParsedStatement.scala | 44 +++ .../sql/catalyst/parser/DDLParserSuite.scala | 318 ++++++++++++++++++ .../spark/sql/execution/SparkSqlParser.scala | 233 ++----------- .../datasources/DataSourceResolution.scala | 112 ++++++ .../internal/BaseSessionStateBuilder.scala | 1 + .../sql/execution/SparkSqlParserSuite.scala | 13 - .../execution/command/DDLParserSuite.scala | 251 +------------- .../command/PlanResolutionSuite.scala | 257 ++++++++++++++ .../sql/hive/HiveSessionStateBuilder.scala | 1 + 11 files changed, 1031 insertions(+), 458 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d2f26e700b6fb..3732b437c8511 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -31,12 +31,14 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.catalyst.catalog.CatalogStorageFormat +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement} +import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.CalendarInterval @@ -1868,4 +1870,193 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging val structField = StructField(identifier.getText, typedVisit(dataType), nullable = true) if (STRING == null) structField else structField.withComment(string(STRING)) } + + /** + * Create location string. + */ + override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) { + string(ctx.STRING) + } + + /** + * Create a [[BucketSpec]]. 
+ */ + override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) { + BucketSpec( + ctx.INTEGER_VALUE.getText.toInt, + visitIdentifierList(ctx.identifierList), + Option(ctx.orderedIdentifierList) + .toSeq + .flatMap(_.orderedIdentifier.asScala) + .map { orderedIdCtx => + Option(orderedIdCtx.ordering).map(_.getText).foreach { dir => + if (dir.toLowerCase(Locale.ROOT) != "asc") { + operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx) + } + } + + orderedIdCtx.identifier.getText + }) + } + + /** + * Convert a table property list into a key-value map. + * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]]. + */ + override def visitTablePropertyList( + ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { + val properties = ctx.tableProperty.asScala.map { property => + val key = visitTablePropertyKey(property.key) + val value = visitTablePropertyValue(property.value) + key -> value + } + // Check for duplicate property names. + checkDuplicateKeys(properties, ctx) + properties.toMap + } + + /** + * Parse a key-value map from a [[TablePropertyListContext]], assuming all values are specified. + */ + def visitPropertyKeyValues(ctx: TablePropertyListContext): Map[String, String] = { + val props = visitTablePropertyList(ctx) + val badKeys = props.collect { case (key, null) => key } + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props + } + + /** + * Parse a list of keys from a [[TablePropertyListContext]], assuming no values are specified. + */ + def visitPropertyKeys(ctx: TablePropertyListContext): Seq[String] = { + val props = visitTablePropertyList(ctx) + val badKeys = props.filter { case (_, v) => v != null }.keys + if (badKeys.nonEmpty) { + operationNotAllowed( + s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) + } + props.keys.toSeq + } + + /** + * A table property key can either be String or a collection of dot separated elements. This + * function extracts the property key based on whether its a string literal or a table property + * identifier. + */ + override def visitTablePropertyKey(key: TablePropertyKeyContext): String = { + if (key.STRING != null) { + string(key.STRING) + } else { + key.getText + } + } + + /** + * A table property value can be String, Integer, Boolean or Decimal. This function extracts + * the property value based on whether its a string, integer, boolean or decimal literal. + */ + override def visitTablePropertyValue(value: TablePropertyValueContext): String = { + if (value == null) { + null + } else if (value.STRING != null) { + string(value.STRING) + } else if (value.booleanValue != null) { + value.getText.toLowerCase(Locale.ROOT) + } else { + value.getText + } + } + + /** + * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal). + */ + type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean) + + /** + * Validate a create table statement and return the [[TableIdentifier]]. + */ + override def visitCreateTableHeader( + ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { + val temporary = ctx.TEMPORARY != null + val ifNotExists = ctx.EXISTS != null + if (temporary && ifNotExists) { + operationNotAllowed("CREATE TEMPORARY TABLE ... 
IF NOT EXISTS", ctx) + } + (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null) + } + + /** + * Create a table, returning a [[CreateTableStatement]] logical plan. + * + * Expected format: + * {{{ + * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name + * USING table_provider + * create_table_clauses + * [[AS] select_statement]; + * + * create_table_clauses (order insensitive): + * [OPTIONS table_property_list] + * [PARTITIONED BY (col_name, col_name, ...)] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [TBLPROPERTIES (property_name=property_value, ...)] + * }}} + */ + override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { + val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + if (external) { + operationNotAllowed("CREATE EXTERNAL TABLE ... USING", ctx) + } + + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + + val schema = Option(ctx.colTypeList()).map(createSchema) + val partitionCols: Seq[String] = + Option(ctx.partitionColumnNames).map(visitIdentifierList).getOrElse(Nil) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) + val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) + + val provider = ctx.tableProvider.qualifiedName.getText + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) + val comment = Option(ctx.comment).map(string) + + Option(ctx.query).map(plan) match { + case Some(_) if temp => + operationNotAllowed("CREATE TEMPORARY TABLE ... USING ... AS query", ctx) + + case Some(_) if schema.isDefined => + operationNotAllowed( + "Schema may not be specified in a Create Table As Select (CTAS) statement", + ctx) + + case Some(query) => + CreateTableAsSelectStatement( + table, query, partitionCols, bucketSpec, properties, provider, options, location, comment, + ifNotExists = ifNotExists) + + case None if temp => + // CREATE TEMPORARY TABLE ... USING ... is not supported by the catalyst parser. + // Use CREATE TEMPORARY VIEW ... USING ... instead. + operationNotAllowed("CREATE TEMPORARY TABLE IF NOT EXISTS", ctx) + + case _ => + CreateTableStatement(table, schema.getOrElse(new StructType), partitionCols, bucketSpec, + properties, provider, options, location, comment, ifNotExists = ifNotExists) + } + } + } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala new file mode 100644 index 0000000000000..c734968e838db --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.sql + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.types.StructType + +/** + * A CREATE TABLE command, as parsed from SQL. + * + * This is a metadata-only command and is not used to write data to the created table. + */ +case class CreateTableStatement( + table: TableIdentifier, + tableSchema: StructType, + partitioning: Seq[String], + bucketSpec: Option[BucketSpec], + properties: Map[String, String], + provider: String, + options: Map[String, String], + location: Option[String], + comment: Option[String], + ifNotExists: Boolean) extends ParsedStatement { + + override def output: Seq[Attribute] = Seq.empty + + override def children: Seq[LogicalPlan] = Seq.empty +} + +/** + * A CREATE TABLE AS SELECT command, as parsed from SQL. + */ +case class CreateTableAsSelectStatement( + table: TableIdentifier, + asSelect: LogicalPlan, + partitioning: Seq[String], + bucketSpec: Option[BucketSpec], + properties: Map[String, String], + provider: String, + options: Map[String, String], + location: Option[String], + comment: Option[String], + ifNotExists: Boolean) extends ParsedStatement { + + override def output: Seq[Attribute] = Seq.empty + + override def children: Seq[LogicalPlan] = Seq(asSelect) +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala new file mode 100644 index 0000000000000..510f2a1ba1e6d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.sql + +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * A logical plan node that contains exactly what was parsed from SQL. 
+ * + * This is used to hold information parsed from SQL when there are multiple implementations of a + * query or command. For example, CREATE TABLE may be implemented by different nodes for v1 and v2. + * Instead of parsing directly to a v1 CreateTable that keeps metadata in CatalogTable, and then + * converting that v1 metadata to the v2 equivalent, the sql [[CreateTableStatement]] plan is + * produced by the parser and converted once into both implementations. + * + * Parsed logical plans are not resolved because they must be converted to concrete logical plans. + * + * Parsed logical plans are located in Catalyst so that as much SQL parsing logic as possible is + * kept in a [[org.apache.spark.sql.catalyst.parser.AbstractSqlParser]]. + */ +private[sql] abstract class ParsedStatement extends LogicalPlan { + // Redact properties and options when parsed nodes are used by generic methods like toString + override def productIterator: Iterator[Any] = super.productIterator.map { + case mapArg: Map[_, _] => conf.redactOptions(mapArg.asInstanceOf[Map[String, String]]) + case other => other + } + + final override lazy val resolved = false +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala new file mode 100644 index 0000000000000..dae8f582c7716 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -0,0 +1,318 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.sql.catalyst.parser + +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.AnalysisTest +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement} +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} + +class DDLParserSuite extends AnalysisTest { + import CatalystSqlParser._ + + private def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parsePlan(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } + + test("create table using - schema") { + val sql = "CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet" + + parsePlan(sql) match { + case create: CreateTableStatement => + assert(create.table == TableIdentifier("my_tab")) + assert(create.tableSchema == new StructType() + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType)) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.isEmpty) + assert(create.properties.isEmpty) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.isEmpty) + assert(create.comment.isEmpty) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + + intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING) USING parquet", + "no viable alternative at input") + } + + test("create table - with IF NOT EXISTS") { + val sql = "CREATE TABLE IF NOT EXISTS my_tab(a INT, b STRING) USING parquet" + + parsePlan(sql) match { + case create: CreateTableStatement => + assert(create.table == TableIdentifier("my_tab")) + assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType)) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.isEmpty) + assert(create.properties.isEmpty) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.isEmpty) + assert(create.comment.isEmpty) + assert(create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("create table - with partitioned by") { + val query = "CREATE TABLE my_tab(a INT comment 'test', b STRING) " + + "USING parquet PARTITIONED BY (a)" + + parsePlan(query) match { + case create: CreateTableStatement => + assert(create.table == TableIdentifier("my_tab")) + assert(create.tableSchema == new StructType() + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType)) + assert(create.partitioning == Seq("a")) + assert(create.bucketSpec.isEmpty) + assert(create.properties.isEmpty) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.isEmpty) + assert(create.comment.isEmpty) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $query") + } + } + + test("create table - with bucket") { + val query = "CREATE TABLE my_tab(a INT, b STRING) USING parquet " + + "CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS" + + parsePlan(query) match { + case create: CreateTableStatement => + assert(create.table == TableIdentifier("my_tab")) + assert(create.tableSchema == new 
StructType().add("a", IntegerType).add("b", StringType)) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.contains(BucketSpec(5, Seq("a"), Seq("b")))) + assert(create.properties.isEmpty) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.isEmpty) + assert(create.comment.isEmpty) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $query") + } + } + + test("create table - with comment") { + val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet COMMENT 'abc'" + + parsePlan(sql) match { + case create: CreateTableStatement => + assert(create.table == TableIdentifier("my_tab")) + assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType)) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.isEmpty) + assert(create.properties.isEmpty) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.isEmpty) + assert(create.comment.contains("abc")) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("create table - with table properties") { + val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet TBLPROPERTIES('test' = 'test')" + + parsePlan(sql) match { + case create: CreateTableStatement => + assert(create.table == TableIdentifier("my_tab")) + assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType)) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.isEmpty) + assert(create.properties == Map("test" -> "test")) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.isEmpty) + assert(create.comment.isEmpty) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("create table - with location") { + val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'" + + parsePlan(sql) match { + case create: CreateTableStatement => + assert(create.table == TableIdentifier("my_tab")) + assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType)) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.isEmpty) + assert(create.properties.isEmpty) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.contains("/tmp/file")) + assert(create.comment.isEmpty) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("create table - byte length literal table name") { + val sql = "CREATE TABLE 1m.2g(a INT) USING parquet" + + parsePlan(sql) match { + case create: CreateTableStatement => + assert(create.table == TableIdentifier("2g", Some("1m"))) + assert(create.tableSchema == new StructType().add("a", IntegerType)) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.isEmpty) + assert(create.properties.isEmpty) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.isEmpty) + assert(create.comment.isEmpty) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse 
${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("Duplicate clauses - create table") { + def createTableHeader(duplicateClause: String): String = { + s"CREATE TABLE my_tab(a INT, b STRING) USING parquet $duplicateClause $duplicateClause" + } + + intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')"), + "Found duplicate clauses: TBLPROPERTIES") + intercept(createTableHeader("LOCATION '/tmp/file'"), + "Found duplicate clauses: LOCATION") + intercept(createTableHeader("COMMENT 'a table'"), + "Found duplicate clauses: COMMENT") + intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS"), + "Found duplicate clauses: CLUSTERED BY") + intercept(createTableHeader("PARTITIONED BY (b)"), + "Found duplicate clauses: PARTITIONED BY") + } + + test("support for other types in OPTIONS") { + val sql = + """ + |CREATE TABLE table_name USING json + |OPTIONS (a 1, b 0.1, c TRUE) + """.stripMargin + + parsePlan(sql) match { + case create: CreateTableStatement => + assert(create.table == TableIdentifier("table_name")) + assert(create.tableSchema == new StructType) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.isEmpty) + assert(create.properties.isEmpty) + assert(create.provider == "json") + assert(create.options == Map("a" -> "1", "b" -> "0.1", "c" -> "true")) + assert(create.location.isEmpty) + assert(create.comment.isEmpty) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("Test CTAS against native tables") { + val s1 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s2 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |LOCATION '/user/external/page_view' + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + parsePlan(sql) match { + case create: CreateTableAsSelectStatement => + assert(create.table == TableIdentifier("page_view", Some("mydb"))) + assert(create.partitioning.isEmpty) + assert(create.bucketSpec.isEmpty) + assert(create.properties == Map("p1" -> "v1", "p2" -> "v2")) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.contains("/user/external/page_view")) + assert(create.comment.contains("This is the staging page view table")) + assert(create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableAsSelectStatement].getClass.getName} " + + s"from query, got ${other.getClass.getName}: $sql") + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index c17cf5de9066b..b997399007cd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -25,7 +25,7 @@ import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.tree.TerminalNode import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} +import org.apache.spark.sql.catalyst.FunctionIdentifier import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.parser._ @@ -376,128 +376,46 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { DescribeQueryCommand(visitQueryToDesc(ctx.queryToDesc())) } - /** - * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal). - */ - type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean) - - /** - * Validate a create table statement and return the [[TableIdentifier]]. - */ - override def visitCreateTableHeader( - ctx: CreateTableHeaderContext): TableHeader = withOrigin(ctx) { - val temporary = ctx.TEMPORARY != null - val ifNotExists = ctx.EXISTS != null - if (temporary && ifNotExists) { - operationNotAllowed("CREATE TEMPORARY TABLE ... IF NOT EXISTS", ctx) - } - (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null) - } - /** * Create a table, returning a [[CreateTable]] logical plan. * - * Expected format: - * {{{ - * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name - * USING table_provider - * create_table_clauses - * [[AS] select_statement]; + * This is used to produce CreateTempViewUsing from CREATE TEMPORARY TABLE. * - * create_table_clauses (order insensitive): - * [OPTIONS table_property_list] - * [PARTITIONED BY (col_name, col_name, ...)] - * [CLUSTERED BY (col_name, col_name, ...) - * [SORTED BY (col_name [ASC|DESC], ...)] - * INTO num_buckets BUCKETS - * ] - * [LOCATION path] - * [COMMENT table_comment] - * [TBLPROPERTIES (property_name=property_value, ...)] - * }}} + * TODO: Remove this. It is used because CreateTempViewUsing is not a Catalyst plan. + * Either move CreateTempViewUsing into catalyst as a parsed logical plan, or remove it because + * it is deprecated. */ override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) - if (external) { - operationNotAllowed("CREATE EXTERNAL TABLE ... 
USING", ctx) - } - - checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) - checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) - checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) - checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) - checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) - checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) - - val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) - val provider = ctx.tableProvider.qualifiedName.getText - val schema = Option(ctx.colTypeList()).map(createSchema) - val partitionColumnNames = - Option(ctx.partitionColumnNames) - .map(visitIdentifierList(_).toArray) - .getOrElse(Array.empty[String]) - val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) - val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) - - val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) - val storage = DataSource.buildStorageFormatFromOptions(options) - if (location.isDefined && storage.locationUri.isDefined) { - throw new ParseException( - "LOCATION and 'path' in OPTIONS are both used to indicate the custom table path, " + - "you can only specify one of them.", ctx) - } - val customLocation = storage.locationUri.orElse(location.map(CatalogUtils.stringToURI)) - - val tableType = if (customLocation.isDefined) { - CatalogTableType.EXTERNAL + if (!temp || ctx.query != null) { + super.visitCreateTable(ctx) } else { - CatalogTableType.MANAGED - } - - val tableDesc = CatalogTable( - identifier = table, - tableType = tableType, - storage = storage.copy(locationUri = customLocation), - schema = schema.getOrElse(new StructType), - provider = Some(provider), - partitionColumnNames = partitionColumnNames, - bucketSpec = bucketSpec, - properties = properties, - comment = Option(ctx.comment).map(string)) - - // Determine the storage mode. - val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists - - if (ctx.query != null) { - // Get the backing query. - val query = plan(ctx.query) - - if (temp) { - operationNotAllowed("CREATE TEMPORARY TABLE ... USING ... AS query", ctx) + if (external) { + operationNotAllowed("CREATE EXTERNAL TABLE ... USING", ctx) } - // Don't allow explicit specification of schema for CTAS - if (schema.nonEmpty) { - operationNotAllowed( - "Schema may not be specified in a Create Table As Select (CTAS) statement", - ctx) - } - CreateTable(tableDesc, mode, Some(query)) - } else { - if (temp) { - if (ifNotExists) { - operationNotAllowed("CREATE TEMPORARY TABLE IF NOT EXISTS", ctx) - } + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) - logWarning(s"CREATE TEMPORARY TABLE ... USING ... is deprecated, please use " + - "CREATE TEMPORARY VIEW ... USING ... instead") + if (ifNotExists) { // Unlike CREATE TEMPORARY VIEW USING, CREATE TEMPORARY TABLE USING does not support // IF NOT EXISTS. Users are not allowed to replace the existing temp table. 
- CreateTempViewUsing(table, schema, replace = false, global = false, provider, options) - } else { - CreateTable(tableDesc, mode, None) + operationNotAllowed("CREATE TEMPORARY TABLE IF NOT EXISTS", ctx) } + + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) + val provider = ctx.tableProvider.qualifiedName.getText + val schema = Option(ctx.colTypeList()).map(createSchema) + + logWarning(s"CREATE TEMPORARY TABLE ... USING ... is deprecated, please use " + + "CREATE TEMPORARY VIEW ... USING ... instead") + + CreateTempViewUsing(table, schema, replace = false, global = false, provider, options) } } @@ -562,77 +480,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { "MSCK REPAIR TABLE") } - /** - * Convert a table property list into a key-value map. - * This should be called through [[visitPropertyKeyValues]] or [[visitPropertyKeys]]. - */ - override def visitTablePropertyList( - ctx: TablePropertyListContext): Map[String, String] = withOrigin(ctx) { - val properties = ctx.tableProperty.asScala.map { property => - val key = visitTablePropertyKey(property.key) - val value = visitTablePropertyValue(property.value) - key -> value - } - // Check for duplicate property names. - checkDuplicateKeys(properties, ctx) - properties.toMap - } - - /** - * Parse a key-value map from a [[TablePropertyListContext]], assuming all values are specified. - */ - private def visitPropertyKeyValues(ctx: TablePropertyListContext): Map[String, String] = { - val props = visitTablePropertyList(ctx) - val badKeys = props.collect { case (key, null) => key } - if (badKeys.nonEmpty) { - operationNotAllowed( - s"Values must be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) - } - props - } - - /** - * Parse a list of keys from a [[TablePropertyListContext]], assuming no values are specified. - */ - private def visitPropertyKeys(ctx: TablePropertyListContext): Seq[String] = { - val props = visitTablePropertyList(ctx) - val badKeys = props.filter { case (_, v) => v != null }.keys - if (badKeys.nonEmpty) { - operationNotAllowed( - s"Values should not be specified for key(s): ${badKeys.mkString("[", ",", "]")}", ctx) - } - props.keys.toSeq - } - - /** - * A table property key can either be String or a collection of dot separated elements. This - * function extracts the property key based on whether its a string literal or a table property - * identifier. - */ - override def visitTablePropertyKey(key: TablePropertyKeyContext): String = { - if (key.STRING != null) { - string(key.STRING) - } else { - key.getText - } - } - - /** - * A table property value can be String, Integer, Boolean or Decimal. This function extracts - * the property value based on whether its a string, integer, boolean or decimal literal. - */ - override def visitTablePropertyValue(value: TablePropertyValueContext): String = { - if (value == null) { - null - } else if (value.STRING != null) { - string(value.STRING) - } else if (value.booleanValue != null) { - value.getText.toLowerCase(Locale.ROOT) - } else { - value.getText - } - } - /** * Create a [[CreateDatabaseCommand]] command. * @@ -1006,34 +853,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { newColumn = visitColType(ctx.colType)) } - /** - * Create location string. - */ - override def visitLocationSpec(ctx: LocationSpecContext): String = withOrigin(ctx) { - string(ctx.STRING) - } - - /** - * Create a [[BucketSpec]]. 
- */ - override def visitBucketSpec(ctx: BucketSpecContext): BucketSpec = withOrigin(ctx) { - BucketSpec( - ctx.INTEGER_VALUE.getText.toInt, - visitIdentifierList(ctx.identifierList), - Option(ctx.orderedIdentifierList) - .toSeq - .flatMap(_.orderedIdentifier.asScala) - .map { orderedIdCtx => - Option(orderedIdCtx.ordering).map(_.getText).foreach { dir => - if (dir.toLowerCase(Locale.ROOT) != "asc") { - operationNotAllowed(s"Column ordering must be ASC, was '$dir'", ctx) - } - } - - orderedIdCtx.identifier.getText - }) - } - /** * Convert a nested constants list into a sequence of string sequences. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala new file mode 100644 index 0000000000000..9fd44ea4e6379 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -0,0 +1,112 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources + +import java.util.Locale + +import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.CastSupport +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement} +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.v2.TableProvider +import org.apache.spark.sql.types.StructType + +case class DataSourceResolution(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { + case CreateTableStatement( + table, schema, partitionCols, bucketSpec, properties, V1WriteProvider(provider), options, + location, comment, ifNotExists) => + + val tableDesc = buildCatalogTable(table, schema, partitionCols, bucketSpec, properties, + provider, options, location, comment, ifNotExists) + val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists + + CreateTable(tableDesc, mode, None) + + case CreateTableAsSelectStatement( + table, query, partitionCols, bucketSpec, properties, V1WriteProvider(provider), options, + location, comment, ifNotExists) => + + val tableDesc = buildCatalogTable(table, new StructType, partitionCols, bucketSpec, + properties, provider, options, location, comment, ifNotExists) + val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists + + CreateTable(tableDesc, mode, Some(query)) + } + + object V1WriteProvider { + private val v1WriteOverrideSet = + conf.userV1SourceWriterList.toLowerCase(Locale.ROOT).split(",").toSet + + def unapply(provider: String): Option[String] = { + if (v1WriteOverrideSet.contains(provider.toLowerCase(Locale.ROOT))) { + Some(provider) + } else { + lazy val providerClass = DataSource.lookupDataSource(provider, conf) + provider match { + case _ if classOf[TableProvider].isAssignableFrom(providerClass) => + None + case _ => + Some(provider) + } + } + } + } + + private def buildCatalogTable( + table: TableIdentifier, + schema: StructType, + partitionColumnNames: Seq[String], + bucketSpec: Option[BucketSpec], + properties: Map[String, String], + provider: String, + options: Map[String, String], + location: Option[String], + comment: Option[String], + ifNotExists: Boolean): CatalogTable = { + + val storage = DataSource.buildStorageFormatFromOptions(options) + if (location.isDefined && storage.locationUri.isDefined) { + throw new AnalysisException( + "LOCATION and 'path' in OPTIONS are both used to indicate the custom table path, " + + "you can only specify one of them.") + } + val customLocation = storage.locationUri.orElse(location.map(CatalogUtils.stringToURI)) + + val tableType = if (customLocation.isDefined) { + CatalogTableType.EXTERNAL + } else { + CatalogTableType.MANAGED + } + + CatalogTable( + identifier = table, + tableType = tableType, + storage = storage.copy(locationUri = customLocation), + schema = schema, + provider = Some(provider), + partitionColumnNames = partitionColumnNames, + bucketSpec = bucketSpec, + properties = properties, + comment = comment) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index f05aa5113e03a..d5543e8a31aad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -161,6 +161,7 @@ abstract class BaseSessionStateBuilder( new FindDataSourceTable(session) +: new ResolveSQLOnFile(session) +: new FallbackOrcDataSourceV2(session) +: + DataSourceResolution(conf) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index 425a96b871ad2..be3d0794d4036 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -215,19 +215,6 @@ class SparkSqlParserSuite extends AnalysisTest { "no viable alternative at input") } - test("create table using - schema") { - assertEqual("CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet", - createTableUsing( - table = "my_tab", - schema = (new StructType) - .add("a", IntegerType, nullable = true, "test") - .add("b", StringType) - ) - ) - intercept("CREATE TABLE my_tab(a: INT COMMENT 'test', b: STRING) USING parquet", - "no viable alternative at input") - } - test("create view as insert into table") { // Single insert query intercept("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index e0ccae15f1d05..d430eeb294e13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -415,173 +415,28 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { assert(ct.tableDesc.storage.locationUri == Some(new URI("/something/anything"))) } - test("create table - with partitioned by") { - val query = "CREATE TABLE my_tab(a INT comment 'test', b STRING) " + - "USING parquet PARTITIONED BY (a)" - - val expectedTableDesc = CatalogTable( - identifier = TableIdentifier("my_tab"), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty, - schema = new StructType() - .add("a", IntegerType, nullable = true, "test") - .add("b", StringType), - provider = Some("parquet"), - partitionColumnNames = Seq("a") - ) - - parser.parsePlan(query) match { - case CreateTable(tableDesc, _, None) => - assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) - case other => - fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + - s"got ${other.getClass.getName}: $query") - } - } - - test("create table - with bucket") { - val query = "CREATE TABLE my_tab(a INT, b STRING) USING parquet " + - "CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS" - - val expectedTableDesc = CatalogTable( - identifier = TableIdentifier("my_tab"), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty, - schema = new StructType().add("a", IntegerType).add("b", StringType), - provider = Some("parquet"), - bucketSpec = Some(BucketSpec(5, Seq("a"), Seq("b"))) - ) - - parser.parsePlan(query) match { - case CreateTable(tableDesc, 
_, None) => - assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) - case other => - fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + - s"got ${other.getClass.getName}: $query") - } - } - - test("create table - with comment") { - val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet COMMENT 'abc'" - - val expectedTableDesc = CatalogTable( - identifier = TableIdentifier("my_tab"), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty, - schema = new StructType().add("a", IntegerType).add("b", StringType), - provider = Some("parquet"), - comment = Some("abc")) - - parser.parsePlan(sql) match { - case CreateTable(tableDesc, _, None) => - assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) - case other => - fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + - s"got ${other.getClass.getName}: $sql") + test("Duplicate clauses - create hive table") { + def createTableHeader(duplicateClause: String): String = { + s"CREATE TABLE my_tab(a INT, b STRING) STORED AS parquet $duplicateClause $duplicateClause" } - } - - test("create table - with table properties") { - val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet TBLPROPERTIES('test' = 'test')" - val expectedTableDesc = CatalogTable( - identifier = TableIdentifier("my_tab"), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty, - schema = new StructType().add("a", IntegerType).add("b", StringType), - provider = Some("parquet"), - properties = Map("test" -> "test")) - - parser.parsePlan(sql) match { - case CreateTable(tableDesc, _, None) => - assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) - case other => - fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + - s"got ${other.getClass.getName}: $sql") - } - } - - test("Duplicate clauses - create table") { - def createTableHeader(duplicateClause: String, isNative: Boolean): String = { - val fileFormat = if (isNative) "USING parquet" else "STORED AS parquet" - s"CREATE TABLE my_tab(a INT, b STRING) $fileFormat $duplicateClause $duplicateClause" - } - - Seq(true, false).foreach { isNative => - intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')", isNative), - "Found duplicate clauses: TBLPROPERTIES") - intercept(createTableHeader("LOCATION '/tmp/file'", isNative), - "Found duplicate clauses: LOCATION") - intercept(createTableHeader("COMMENT 'a table'", isNative), - "Found duplicate clauses: COMMENT") - intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS", isNative), - "Found duplicate clauses: CLUSTERED BY") - } - - // Only for native data source tables - intercept(createTableHeader("PARTITIONED BY (b)", isNative = true), + intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')"), + "Found duplicate clauses: TBLPROPERTIES") + intercept(createTableHeader("LOCATION '/tmp/file'"), + "Found duplicate clauses: LOCATION") + intercept(createTableHeader("COMMENT 'a table'"), + "Found duplicate clauses: COMMENT") + intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS"), + "Found duplicate clauses: CLUSTERED BY") + intercept(createTableHeader("PARTITIONED BY (k int)"), "Found duplicate clauses: PARTITIONED BY") - - // Only for Hive serde tables - intercept(createTableHeader("PARTITIONED BY (k int)", isNative = false), - "Found duplicate clauses: PARTITIONED BY") - intercept(createTableHeader("STORED AS parquet", isNative 
= false), + intercept(createTableHeader("STORED AS parquet"), "Found duplicate clauses: STORED AS/BY") intercept( - createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'", isNative = false), + createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'"), "Found duplicate clauses: ROW FORMAT") } - test("create table - with location") { - val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'" - - val expectedTableDesc = CatalogTable( - identifier = TableIdentifier("my_tab"), - tableType = CatalogTableType.EXTERNAL, - storage = CatalogStorageFormat.empty.copy(locationUri = Some(new URI("/tmp/file"))), - schema = new StructType().add("a", IntegerType).add("b", StringType), - provider = Some("parquet")) - - parser.parsePlan(v1) match { - case CreateTable(tableDesc, _, None) => - assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) - case other => - fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + - s"got ${other.getClass.getName}: $v1") - } - - val v2 = - """ - |CREATE TABLE my_tab(a INT, b STRING) - |USING parquet - |OPTIONS (path '/tmp/file') - |LOCATION '/tmp/file' - """.stripMargin - val e = intercept[ParseException] { - parser.parsePlan(v2) - } - assert(e.message.contains("you can only specify one of them.")) - } - - test("create table - byte length literal table name") { - val sql = "CREATE TABLE 1m.2g(a INT) USING parquet" - - val expectedTableDesc = CatalogTable( - identifier = TableIdentifier("2g", Some("1m")), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty, - schema = new StructType().add("a", IntegerType), - provider = Some("parquet")) - - parser.parsePlan(sql) match { - case CreateTable(tableDesc, _, None) => - assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) - case other => - fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + - s"got ${other.getClass.getName}: $sql") - } - } - test("insert overwrite directory") { val v1 = "INSERT OVERWRITE DIRECTORY '/tmp/file' USING parquet SELECT 1 as a" parser.parsePlan(v1) match { @@ -1165,84 +1020,6 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { comparePlans(parsed, expected) } - test("support for other types in OPTIONS") { - val sql = - """ - |CREATE TABLE table_name USING json - |OPTIONS (a 1, b 0.1, c TRUE) - """.stripMargin - - val expectedTableDesc = CatalogTable( - identifier = TableIdentifier("table_name"), - tableType = CatalogTableType.MANAGED, - storage = CatalogStorageFormat.empty.copy( - properties = Map("a" -> "1", "b" -> "0.1", "c" -> "true") - ), - schema = new StructType, - provider = Some("json") - ) - - parser.parsePlan(sql) match { - case CreateTable(tableDesc, _, None) => - assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) - case other => - fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + - s"got ${other.getClass.getName}: $sql") - } - } - - test("Test CTAS against data source tables") { - val s1 = - """ - |CREATE TABLE IF NOT EXISTS mydb.page_view - |USING parquet - |COMMENT 'This is the staging page view table' - |LOCATION '/user/external/page_view' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src - """.stripMargin - - val s2 = - """ - |CREATE TABLE IF NOT EXISTS mydb.page_view - |USING parquet - |LOCATION '/user/external/page_view' - |COMMENT 'This is the staging page view table' - |TBLPROPERTIES 
('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src - """.stripMargin - - val s3 = - """ - |CREATE TABLE IF NOT EXISTS mydb.page_view - |USING parquet - |COMMENT 'This is the staging page view table' - |LOCATION '/user/external/page_view' - |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src - """.stripMargin - - checkParsing(s1) - checkParsing(s2) - checkParsing(s3) - - def checkParsing(sql: String): Unit = { - val (desc, exists) = extractTableDesc(sql) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - assert(desc.comment == Some("This is the staging page view table")) - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.provider == Some("parquet")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) - } - } - test("Test CTAS #1") { val s1 = """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala new file mode 100644 index 0000000000000..89c5df0900b61 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -0,0 +1,257 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.command + +import java.net.URI + +import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.analysis.AnalysisTest +import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.datasources.{CreateTable, DataSourceResolution} +import org.apache.spark.sql.types.{IntegerType, StringType, StructType} + +class PlanResolutionSuite extends AnalysisTest { + import CatalystSqlParser._ + + def parseAndResolve(query: String): LogicalPlan = { + DataSourceResolution(conf).apply(parsePlan(query)) + } + + private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { + parseAndResolve(sql).collect { + case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore) + }.head + } + + test("create table - with partitioned by") { + val query = "CREATE TABLE my_tab(a INT comment 'test', b STRING) " + + "USING parquet PARTITIONED BY (a)" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("my_tab"), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType() + .add("a", IntegerType, nullable = true, "test") + .add("b", StringType), + provider = Some("parquet"), + partitionColumnNames = Seq("a") + ) + + parseAndResolve(query) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $query") + } + } + + test("create table - with bucket") { + val query = "CREATE TABLE my_tab(a INT, b STRING) USING parquet " + + "CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("my_tab"), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("a", IntegerType).add("b", StringType), + provider = Some("parquet"), + bucketSpec = Some(BucketSpec(5, Seq("a"), Seq("b"))) + ) + + parseAndResolve(query) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $query") + } + } + + test("create table - with comment") { + val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet COMMENT 'abc'" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("my_tab"), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("a", IntegerType).add("b", StringType), + provider = Some("parquet"), + comment = Some("abc")) + + parseAndResolve(sql) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("create table - with table properties") { + val sql = "CREATE TABLE my_tab(a INT, b STRING) USING parquet TBLPROPERTIES('test' = 'test')" + + val expectedTableDesc 
= CatalogTable( + identifier = TableIdentifier("my_tab"), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("a", IntegerType).add("b", StringType), + provider = Some("parquet"), + properties = Map("test" -> "test")) + + parseAndResolve(sql) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("create table - with location") { + val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("my_tab"), + tableType = CatalogTableType.EXTERNAL, + storage = CatalogStorageFormat.empty.copy(locationUri = Some(new URI("/tmp/file"))), + schema = new StructType().add("a", IntegerType).add("b", StringType), + provider = Some("parquet")) + + parseAndResolve(v1) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $v1") + } + + val v2 = + """ + |CREATE TABLE my_tab(a INT, b STRING) + |USING parquet + |OPTIONS (path '/tmp/file') + |LOCATION '/tmp/file' + """.stripMargin + val e = intercept[AnalysisException] { + parseAndResolve(v2) + } + assert(e.message.contains("you can only specify one of them.")) + } + + test("create table - byte length literal table name") { + val sql = "CREATE TABLE 1m.2g(a INT) USING parquet" + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("2g", Some("1m")), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty, + schema = new StructType().add("a", IntegerType), + provider = Some("parquet")) + + parseAndResolve(sql) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("support for other types in OPTIONS") { + val sql = + """ + |CREATE TABLE table_name USING json + |OPTIONS (a 1, b 0.1, c TRUE) + """.stripMargin + + val expectedTableDesc = CatalogTable( + identifier = TableIdentifier("table_name"), + tableType = CatalogTableType.MANAGED, + storage = CatalogStorageFormat.empty.copy( + properties = Map("a" -> "1", "b" -> "0.1", "c" -> "true") + ), + schema = new StructType, + provider = Some("json") + ) + + parseAndResolve(sql) match { + case CreateTable(tableDesc, _, None) => + assert(tableDesc == expectedTableDesc.copy(createTime = tableDesc.createTime)) + case other => + fail(s"Expected to parse ${classOf[CreateTableCommand].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("Test CTAS against data source tables") { + val s1 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s2 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |LOCATION '/user/external/page_view' + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES 
('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database.contains("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.storage.locationUri.contains(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment.contains("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.provider.contains("parquet")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } + } +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 68f4b2ddbac0b..877a0dadf0b03 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -73,6 +73,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session new FindDataSourceTable(session) +: new ResolveSQLOnFile(session) +: new FallbackOrcDataSourceV2(session) +: + DataSourceResolution(conf) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = From 0f9ac2a27c22aacc766d9586ec79d7c7f0c115cd Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 10 Apr 2019 14:30:39 +0800 Subject: [PATCH 27/70] [SPARK-27181][SQL] Add public transform API ## What changes were proposed in this pull request? This adds a public Expression API that can be used to pass partition transformations to data sources. ## How was this patch tested? Existing tests to validate no regressions. Added transform cases to DDL suite and v1 conversions suite. Closes #24117 from rdblue/add-public-transform-api. 
Authored-by: Ryan Blue Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/parser/SqlBase.g4 | 17 +- .../catalog/v2/expressions/Expression.java | 31 +++ .../catalog/v2/expressions/Expressions.java | 162 ++++++++++++++ .../sql/catalog/v2/expressions/Literal.java | 42 ++++ .../v2/expressions/NamedReference.java | 33 +++ .../sql/catalog/v2/expressions/Transform.java | 44 ++++ .../catalog/v2/expressions/expressions.scala | 203 ++++++++++++++++++ .../sql/catalyst/parser/AstBuilder.scala | 101 ++++++++- .../logical/sql/CreateTableStatement.scala | 5 +- .../sql/catalyst/parser/DDLParserSuite.scala | 52 ++++- .../datasources/DataSourceResolution.scala | 7 +- .../command/PlanResolutionSuite.scala | 20 ++ 12 files changed, 705 insertions(+), 12 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expression.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Literal.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/NamedReference.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Transform.java create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 0963ecf00c3d1..d9caea170bc47 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -89,7 +89,7 @@ statement | DROP database (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase | createTableHeader ('(' colTypeList ')')? tableProvider ((OPTIONS options=tablePropertyList) | - (PARTITIONED BY partitionColumnNames=identifierList) | + (PARTITIONED BY partitioning=transformList) | bucketSpec | locationSpec | (COMMENT comment=STRING) | @@ -578,6 +578,21 @@ namedExpressionSeq : namedExpression (',' namedExpression)* ; +transformList + : '(' transforms+=transform (',' transforms+=transform)* ')' + ; + +transform + : qualifiedName #identityTransform + | transformName=identifier + '(' argument+=transformArgument (',' argument+=transformArgument)* ')' #applyTransform + ; + +transformArgument + : qualifiedName + | constant + ; + expression : booleanExpression ; diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expression.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expression.java new file mode 100644 index 0000000000000..1e2aca9556df4 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expression.java @@ -0,0 +1,31 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2.expressions; + +import org.apache.spark.annotation.Experimental; + +/** + * Base class of the public logical expression API. + */ +@Experimental +public interface Expression { + /** + * Format the expression as a human readable SQL-like string. + */ + String describe(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java new file mode 100644 index 0000000000000..009e89bd4eb60 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java @@ -0,0 +1,162 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2.expressions; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.types.DataType; +import scala.collection.JavaConverters; + +import java.util.Arrays; + +/** + * Helper methods to create logical transforms to pass into Spark. + */ +@Experimental +public class Expressions { + private Expressions() { + } + + /** + * Create a logical transform for applying a named transform. + *
+ * This transform can represent applying any named transform. + * + * @param name the transform name + * @param args expression arguments to the transform + * @return a logical transform + */ + public Transform apply(String name, Expression... args) { + return LogicalExpressions.apply(name, + JavaConverters.asScalaBuffer(Arrays.asList(args)).toSeq()); + } + + /** + * Create a named reference expression for a column. + * + * @param name a column name + * @return a named reference for the column + */ + public NamedReference column(String name) { + return LogicalExpressions.reference(name); + } + + /** + * Create a literal from a value. + *
+ * The JVM type of the value held by a literal must be the type used by Spark's InternalRow API + * for the literal's {@link DataType SQL data type}. + * + * @param value a value + * @param the JVM type of the value + * @return a literal expression for the value + */ + public Literal literal(T value) { + return LogicalExpressions.literal(value); + } + + /** + * Create a bucket transform for one or more columns. + *
+ * This transform represents a logical mapping from a value to a bucket id in [0, numBuckets) + * based on a hash of the value. + *
+ * The name reported by transforms created with this method is "bucket". + * + * @param numBuckets the number of output buckets + * @param columns input columns for the bucket transform + * @return a logical bucket transform with name "bucket" + */ + public Transform bucket(int numBuckets, String... columns) { + return LogicalExpressions.bucket(numBuckets, + JavaConverters.asScalaBuffer(Arrays.asList(columns)).toSeq()); + } + + /** + * Create an identity transform for a column. + *
+ * This transform represents a logical mapping from a value to itself. + *
+ * The name reported by transforms created with this method is "identity". + * + * @param column an input column + * @return a logical identity transform with name "identity" + */ + public Transform identity(String column) { + return LogicalExpressions.identity(column); + } + + /** + * Create a yearly transform for a timestamp or date column. + *
+ * This transform represents a logical mapping from a timestamp or date to a year, such as 2018. + *
+ * The name reported by transforms created with this method is "years". + * + * @param column an input timestamp or date column + * @return a logical yearly transform with name "years" + */ + public Transform years(String column) { + return LogicalExpressions.years(column); + } + + /** + * Create a monthly transform for a timestamp or date column. + *
+ * This transform represents a logical mapping from a timestamp or date to a month, such as + * 2018-05. + *
+ * The name reported by transforms created with this method is "months". + * + * @param column an input timestamp or date column + * @return a logical monthly transform with name "months" + */ + public Transform months(String column) { + return LogicalExpressions.months(column); + } + + /** + * Create a daily transform for a timestamp or date column. + *
+ * This transform represents a logical mapping from a timestamp or date to a date, such as + * 2018-05-13. + *
+ * The name reported by transforms created with this method is "days". + * + * @param column an input timestamp or date column + * @return a logical daily transform with name "days" + */ + public Transform days(String column) { + return LogicalExpressions.days(column); + } + + /** + * Create an hourly transform for a timestamp column. + *
+ * This transform represents a logical mapping from a timestamp to a date and hour, such as + * 2018-05-13, hour 19. + *
+ * The name reported by transforms created with this method is "hours". + * + * @param column an input timestamp column + * @return a logical hourly transform with name "hours" + */ + public Transform hours(String column) { + return LogicalExpressions.hours(column); + } + +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Literal.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Literal.java new file mode 100644 index 0000000000000..e41bcf9000c52 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Literal.java @@ -0,0 +1,42 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2.expressions; + +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.types.DataType; + +/** + * Represents a constant literal value in the public expression API. + *
+ * The JVM type of the value held by a literal must be the type used by Spark's InternalRow API for + * the literal's {@link DataType SQL data type}. + * + * @param the JVM type of a value held by the literal + */ +@Experimental +public interface Literal extends Expression { + /** + * Returns the literal value. + */ + T value(); + + /** + * Returns the SQL data type of the literal. + */ + DataType dataType(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/NamedReference.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/NamedReference.java new file mode 100644 index 0000000000000..c71ffbe70651f --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/NamedReference.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2.expressions; + +import org.apache.spark.annotation.Experimental; + +/** + * Represents a field or column reference in the public logical expression API. + */ +@Experimental +public interface NamedReference extends Expression { + /** + * Returns the referenced field name as an array of String parts. + *
+ * Each string in the returned array represents a field name. + */ + String[] fieldNames(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Transform.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Transform.java new file mode 100644 index 0000000000000..c85e0c412f1ab --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Transform.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2.expressions; + +import org.apache.spark.annotation.Experimental; + +/** + * Represents a transform function in the public logical expression API. + *
+ * For example, the transform date(ts) is used to derive a date value from a timestamp column. The + * transform name is "date" and its argument is a reference to the "ts" column. + */ +@Experimental +public interface Transform extends Expression { + /** + * Returns the transform function name. + */ + String name(); + + /** + * Returns all field references in the transform arguments. + */ + NamedReference[] references(); + + /** + * Returns the arguments passed to the transform function. + */ + Expression[] arguments(); +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala new file mode 100644 index 0000000000000..813d88255c6a2 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala @@ -0,0 +1,203 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2.expressions + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DataType, IntegerType, StringType} + +/** + * Helper methods for working with the logical expressions API. + * + * Factory methods can be used when referencing the logical expression nodes is ambiguous because + * logical and internal expressions are used. + */ +private[sql] object LogicalExpressions { + // a generic parser that is only used for parsing multi-part field names. + // because this is only used for field names, the SQL conf passed in does not matter. 
+ private lazy val parser = new CatalystSqlParser(SQLConf.get) + + def fromPartitionColumns(columns: String*): Array[IdentityTransform] = + columns.map(identity).toArray + + def fromBucketSpec(spec: BucketSpec): BucketTransform = { + if (spec.sortColumnNames.nonEmpty) { + throw new AnalysisException( + s"Cannot convert bucketing with sort columns to a transform: $spec") + } + + bucket(spec.numBuckets, spec.bucketColumnNames: _*) + } + + implicit class TransformHelper(transforms: Seq[Transform]) { + def asPartitionColumns: Seq[String] = { + val (idTransforms, nonIdTransforms) = transforms.partition(_.isInstanceOf[IdentityTransform]) + + if (nonIdTransforms.nonEmpty) { + throw new AnalysisException("Transforms cannot be converted to partition columns: " + + nonIdTransforms.map(_.describe).mkString(", ")) + } + + idTransforms.map(_.asInstanceOf[IdentityTransform]).map(_.reference).map { ref => + val parts = ref.fieldNames + if (parts.size > 1) { + throw new AnalysisException(s"Cannot partition by nested column: $ref") + } else { + parts(0) + } + } + } + } + + def literal[T](value: T): LiteralValue[T] = { + val internalLit = catalyst.expressions.Literal(value) + literal(value, internalLit.dataType) + } + + def literal[T](value: T, dataType: DataType): LiteralValue[T] = LiteralValue(value, dataType) + + def reference(name: String): NamedReference = + FieldReference(parser.parseMultipartIdentifier(name)) + + def apply(name: String, arguments: Expression*): Transform = ApplyTransform(name, arguments) + + def bucket(numBuckets: Int, columns: String*): BucketTransform = + BucketTransform(literal(numBuckets, IntegerType), columns.map(reference)) + + def identity(column: String): IdentityTransform = IdentityTransform(reference(column)) + + def years(column: String): YearsTransform = YearsTransform(reference(column)) + + def months(column: String): MonthsTransform = MonthsTransform(reference(column)) + + def days(column: String): DaysTransform = DaysTransform(reference(column)) + + def hours(column: String): HoursTransform = HoursTransform(reference(column)) +} + +/** + * Base class for simple transforms of a single column. 
+ */ +private[sql] abstract class SingleColumnTransform(ref: NamedReference) extends Transform { + + def reference: NamedReference = ref + + override def references: Array[NamedReference] = Array(ref) + + override def arguments: Array[Expression] = Array(ref) + + override def describe: String = name + "(" + reference.describe + ")" + + override def toString: String = describe +} + +private[sql] final case class BucketTransform( + numBuckets: Literal[Int], + columns: Seq[NamedReference]) extends Transform { + + override val name: String = "bucket" + + override def references: Array[NamedReference] = { + arguments + .filter(_.isInstanceOf[NamedReference]) + .map(_.asInstanceOf[NamedReference]) + } + + override def arguments: Array[Expression] = numBuckets +: columns.toArray + + override def describe: String = s"bucket(${arguments.map(_.describe).mkString(", ")})" + + override def toString: String = describe +} + +private[sql] final case class ApplyTransform( + name: String, + args: Seq[Expression]) extends Transform { + + override def arguments: Array[Expression] = args.toArray + + override def references: Array[NamedReference] = { + arguments + .filter(_.isInstanceOf[NamedReference]) + .map(_.asInstanceOf[NamedReference]) + } + + override def describe: String = s"$name(${arguments.map(_.describe).mkString(", ")})" + + override def toString: String = describe +} + +private[sql] final case class IdentityTransform( + ref: NamedReference) extends SingleColumnTransform(ref) { + override val name: String = "identity" + override def describe: String = ref.describe +} + +private[sql] final case class YearsTransform( + ref: NamedReference) extends SingleColumnTransform(ref) { + override val name: String = "years" +} + +private[sql] final case class MonthsTransform( + ref: NamedReference) extends SingleColumnTransform(ref) { + override val name: String = "months" +} + +private[sql] final case class DaysTransform( + ref: NamedReference) extends SingleColumnTransform(ref) { + override val name: String = "days" +} + +private[sql] final case class HoursTransform( + ref: NamedReference) extends SingleColumnTransform(ref) { + override val name: String = "hours" +} + +private[sql] final case class LiteralValue[T](value: T, dataType: DataType) extends Literal[T] { + override def describe: String = { + if (dataType.isInstanceOf[StringType]) { + s"'$value'" + } else { + s"$value" + } + } + override def toString: String = describe +} + +private[sql] final case class FieldReference(parts: Seq[String]) extends NamedReference { + override def fieldNames: Array[String] = parts.toArray + override def describe: String = parts.map(quote).mkString(".") + override def toString: String = describe + + private def quote(part: String): String = { + if (part.contains(".") || part.contains("`")) { + s"`${part.replace("`", "``")}`" + } else { + part + } + } +} + +private[sql] object FieldReference { + def apply(column: String): NamedReference = { + LogicalExpressions.reference(column) + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 3732b437c8511..b18959f2e972a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -29,6 +29,8 @@ import org.antlr.v4.runtime.tree.{ParseTree, RuleNode, TerminalNode} import org.apache.spark.internal.Logging import 
org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalog.v2 +import org.apache.spark.sql.catalog.v2.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, Transform, YearsTransform} import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat} @@ -1988,6 +1990,95 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null) } + /** + * Parse a list of transforms. + */ + override def visitTransformList(ctx: TransformListContext): Seq[Transform] = withOrigin(ctx) { + def getFieldReference( + ctx: ApplyTransformContext, + arg: v2.expressions.Expression): FieldReference = { + lazy val name: String = ctx.identifier.getText + arg match { + case ref: FieldReference => + ref + case nonRef => + throw new ParseException( + s"Expected a column reference for transform $name: ${nonRef.describe}", ctx) + } + } + + def getSingleFieldReference( + ctx: ApplyTransformContext, + arguments: Seq[v2.expressions.Expression]): FieldReference = { + lazy val name: String = ctx.identifier.getText + if (arguments.size > 1) { + throw new ParseException(s"Too many arguments for transform $name", ctx) + } else if (arguments.isEmpty) { + throw new ParseException(s"Not enough arguments for transform $name", ctx) + } else { + getFieldReference(ctx, arguments.head) + } + } + + ctx.transforms.asScala.map { + case identityCtx: IdentityTransformContext => + IdentityTransform(FieldReference( + identityCtx.qualifiedName.identifier.asScala.map(_.getText))) + + case applyCtx: ApplyTransformContext => + val arguments = applyCtx.argument.asScala.map(visitTransformArgument) + + applyCtx.identifier.getText match { + case "bucket" => + val numBuckets: Int = arguments.head match { + case LiteralValue(shortValue, ShortType) => + shortValue.asInstanceOf[Short].toInt + case LiteralValue(intValue, IntegerType) => + intValue.asInstanceOf[Int] + case LiteralValue(longValue, LongType) => + longValue.asInstanceOf[Long].toInt + case lit => + throw new ParseException(s"Invalid number of buckets: ${lit.describe}", applyCtx) + } + + val fields = arguments.tail.map(arg => getFieldReference(applyCtx, arg)) + + BucketTransform(LiteralValue(numBuckets, IntegerType), fields) + + case "years" => + YearsTransform(getSingleFieldReference(applyCtx, arguments)) + + case "months" => + MonthsTransform(getSingleFieldReference(applyCtx, arguments)) + + case "days" => + DaysTransform(getSingleFieldReference(applyCtx, arguments)) + + case "hours" => + HoursTransform(getSingleFieldReference(applyCtx, arguments)) + + case name => + ApplyTransform(name, arguments) + } + } + } + + /** + * Parse an argument to a transform. An argument may be a field reference (qualified name) or + * a value literal. 
+ */ + override def visitTransformArgument(ctx: TransformArgumentContext): v2.expressions.Expression = { + withOrigin(ctx) { + val reference = Option(ctx.qualifiedName) + .map(nameCtx => FieldReference(nameCtx.identifier.asScala.map(_.getText))) + val literal = Option(ctx.constant) + .map(typedVisit[Literal]) + .map(lit => LiteralValue(lit.value, lit.dataType)) + reference.orElse(literal) + .getOrElse(throw new ParseException(s"Invalid transform argument", ctx)) + } + } + /** * Create a table, returning a [[CreateTableStatement]] logical plan. * @@ -2000,7 +2091,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging * * create_table_clauses (order insensitive): * [OPTIONS table_property_list] - * [PARTITIONED BY (col_name, col_name, ...)] + * [PARTITIONED BY (col_name, transform(col_name), transform(constant, col_name), ...)] * [CLUSTERED BY (col_name, col_name, ...) * [SORTED BY (col_name [ASC|DESC], ...)] * INTO num_buckets BUCKETS @@ -2024,8 +2115,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) val schema = Option(ctx.colTypeList()).map(createSchema) - val partitionCols: Seq[String] = - Option(ctx.partitionColumnNames).map(visitIdentifierList).getOrElse(Nil) + val partitioning: Seq[Transform] = + Option(ctx.partitioning).map(visitTransformList).getOrElse(Nil) val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) @@ -2045,7 +2136,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging case Some(query) => CreateTableAsSelectStatement( - table, query, partitionCols, bucketSpec, properties, provider, options, location, comment, + table, query, partitioning, bucketSpec, properties, provider, options, location, comment, ifNotExists = ifNotExists) case None if temp => @@ -2054,7 +2145,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging operationNotAllowed("CREATE TEMPORARY TABLE IF NOT EXISTS", ctx) case _ => - CreateTableStatement(table, schema.getOrElse(new StructType), partitionCols, bucketSpec, + CreateTableStatement(table, schema.getOrElse(new StructType), partitioning, bucketSpec, properties, provider, options, location, comment, ifNotExists = ifNotExists) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala index c734968e838db..ed1b3e3778c7f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.sql +import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.Attribute @@ -31,7 +32,7 @@ import org.apache.spark.sql.types.StructType case class CreateTableStatement( table: TableIdentifier, tableSchema: StructType, - partitioning: Seq[String], + partitioning: Seq[Transform], bucketSpec: Option[BucketSpec], properties: Map[String, String], provider: String, @@ 
-51,7 +52,7 @@ case class CreateTableStatement( case class CreateTableAsSelectStatement( table: TableIdentifier, asSelect: LogicalPlan, - partitioning: Seq[String], + partitioning: Seq[Transform], bucketSpec: Option[BucketSpec], properties: Map[String, String], provider: String, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index dae8f582c7716..98388a74cd29d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -17,11 +17,13 @@ package org.apache.spark.sql.catalyst.parser +import org.apache.spark.sql.catalog.v2.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement} -import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.sql.types.{IntegerType, StringType, StructType, TimestampType} +import org.apache.spark.unsafe.types.UTF8String class DDLParserSuite extends AnalysisTest { import CatalystSqlParser._ @@ -92,7 +94,7 @@ class DDLParserSuite extends AnalysisTest { assert(create.tableSchema == new StructType() .add("a", IntegerType, nullable = true, "test") .add("b", StringType)) - assert(create.partitioning == Seq("a")) + assert(create.partitioning == Seq(IdentityTransform(FieldReference("a")))) assert(create.bucketSpec.isEmpty) assert(create.properties.isEmpty) assert(create.provider == "parquet") @@ -107,6 +109,52 @@ class DDLParserSuite extends AnalysisTest { } } + test("create table - partitioned by transforms") { + val sql = + """ + |CREATE TABLE my_tab (a INT, b STRING, ts TIMESTAMP) USING parquet + |PARTITIONED BY ( + | a, + | bucket(16, b), + | years(ts), + | months(ts), + | days(ts), + | hours(ts), + | foo(a, "bar", 34)) + """.stripMargin + + parsePlan(sql) match { + case create: CreateTableStatement => + assert(create.table == TableIdentifier("my_tab")) + assert(create.tableSchema == new StructType() + .add("a", IntegerType) + .add("b", StringType) + .add("ts", TimestampType)) + assert(create.partitioning == Seq( + IdentityTransform(FieldReference("a")), + BucketTransform(LiteralValue(16, IntegerType), Seq(FieldReference("b"))), + YearsTransform(FieldReference("ts")), + MonthsTransform(FieldReference("ts")), + DaysTransform(FieldReference("ts")), + HoursTransform(FieldReference("ts")), + ApplyTransform("foo", Seq( + FieldReference("a"), + LiteralValue(UTF8String.fromString("bar"), StringType), + LiteralValue(34, IntegerType))))) + assert(create.bucketSpec.isEmpty) + assert(create.properties.isEmpty) + assert(create.provider == "parquet") + assert(create.options.isEmpty) + assert(create.location.isEmpty) + assert(create.comment.isEmpty) + assert(!create.ifNotExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableStatement].getClass.getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + test("create table - with bucket") { val query = "CREATE TABLE my_tab(a INT, b STRING) USING parquet " + "CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS" diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index 9fd44ea4e6379..f503ff03b971c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils} @@ -31,6 +32,8 @@ import org.apache.spark.sql.sources.v2.TableProvider import org.apache.spark.sql.types.StructType case class DataSourceResolution(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { + import org.apache.spark.sql.catalog.v2.expressions.LogicalExpressions.TransformHelper + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case CreateTableStatement( table, schema, partitionCols, bucketSpec, properties, V1WriteProvider(provider), options, @@ -75,7 +78,7 @@ case class DataSourceResolution(conf: SQLConf) extends Rule[LogicalPlan] with Ca private def buildCatalogTable( table: TableIdentifier, schema: StructType, - partitionColumnNames: Seq[String], + partitioning: Seq[Transform], bucketSpec: Option[BucketSpec], properties: Map[String, String], provider: String, @@ -104,7 +107,7 @@ case class DataSourceResolution(conf: SQLConf) extends Rule[LogicalPlan] with Ca storage = storage.copy(locationUri = customLocation), schema = schema, provider = Some(provider), - partitionColumnNames = partitionColumnNames, + partitionColumnNames = partitioning.asPartitionColumns, bucketSpec = bucketSpec, properties = properties, comment = comment) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 89c5df0900b61..7fae54bb95ed1 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -65,6 +65,26 @@ class PlanResolutionSuite extends AnalysisTest { } } + test("create table - partitioned by transforms") { + val transforms = Seq( + "bucket(16, b)", "years(ts)", "months(ts)", "days(ts)", "hours(ts)", "foo(a, 'bar', 34)", + "bucket(32, b), days(ts)") + transforms.foreach { transform => + val query = + s""" + |CREATE TABLE my_tab(a INT, b STRING) USING parquet + |PARTITIONED BY ($transform) + """.stripMargin + + val ae = intercept[AnalysisException] { + parseAndResolve(query) + } + + assert(ae.message + .contains(s"Transforms cannot be converted to partition columns: $transform")) + } + } + test("create table - with bucket") { val query = "CREATE TABLE my_tab(a INT, b STRING) USING parquet " + "CLUSTERED BY (a) SORTED BY (b) INTO 5 BUCKETS" From d70253ea99455752d05c65849805de6c682da7dc Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 8 May 2019 10:31:06 +0800 Subject: [PATCH 28/70] [SPARK-24252][SQL] Add TableCatalog API ## What changes were proposed in this pull request? 
This adds the TableCatalog API proposed in the [Table Metadata API SPIP](https://docs.google.com/document/d/1zLFiA1VuaWeVxeTDXNg8bL6GP3BVoOZBkewFtEnjEoo/edit#heading=h.m45webtwxf2d). For `TableCatalog` to use `Table`, it needed to be moved into the catalyst module where the v2 catalog API is located. This also required moving `TableCapability`. Most of the files touched by this PR are import changes needed by this move. ## How was this patch tested? This adds a test implementation and contract tests. Closes #24246 from rdblue/SPARK-24252-add-table-catalog-api. Authored-by: Ryan Blue Signed-off-by: Wenchen Fan --- .../spark/sql/catalog/v2/IdentifierImpl.java | 25 + .../spark/sql/catalog/v2/TableCatalog.java | 137 ++++ .../spark/sql/catalog/v2/TableChange.java | 366 ++++++++++ .../catalog/v2/expressions/Expressions.java | 18 +- .../apache/spark/sql/sources/v2/Table.java | 21 + .../spark/sql/sources/v2/TableCapability.java | 6 +- .../sql/catalog/v2/CatalogV2Implicits.scala | 98 +++ .../catalog/v2/expressions/expressions.scala | 45 +- .../analysis/AlreadyExistException.scala | 23 +- .../analysis/NoSuchItemException.scala | 22 +- .../sql/catalog/v2/TableCatalogSuite.scala | 657 ++++++++++++++++++ .../sql/catalog/v2/TestTableCatalog.scala | 220 ++++++ .../datasources/DataSourceResolution.scala | 2 +- .../execution/datasources/v2/FileTable.scala | 11 +- .../v2/JavaPartitionAwareDataSource.java | 7 + 15 files changed, 1594 insertions(+), 64 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableCatalog.java create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/Table.java (72%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java (92%) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogV2Implicits.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TableCatalogSuite.scala create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TestTableCatalog.scala diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/IdentifierImpl.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/IdentifierImpl.java index 8874faa71b5bb..cd131432008a6 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/IdentifierImpl.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/IdentifierImpl.java @@ -17,8 +17,12 @@ package org.apache.spark.sql.catalog.v2; +import com.google.common.base.Preconditions; import org.apache.spark.annotation.Experimental; +import java.util.Arrays; +import java.util.Objects; + /** * An {@link Identifier} implementation. 
*/ @@ -29,6 +33,8 @@ class IdentifierImpl implements Identifier { private String name; IdentifierImpl(String[] namespace, String name) { + Preconditions.checkNotNull(namespace, "Identifier namespace cannot be null"); + Preconditions.checkNotNull(name, "Identifier name cannot be null"); this.namespace = namespace; this.name = name; } @@ -42,4 +48,23 @@ public String[] namespace() { public String name() { return name; } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + IdentifierImpl that = (IdentifierImpl) o; + return Arrays.equals(namespace, that.namespace) && name.equals(that.name); + } + + @Override + public int hashCode() { + return Objects.hash(Arrays.hashCode(namespace), name); + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableCatalog.java new file mode 100644 index 0000000000000..681629d2d5405 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableCatalog.java @@ -0,0 +1,137 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.sql.catalog.v2.expressions.Transform; +import org.apache.spark.sql.catalyst.analysis.NoSuchNamespaceException; +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException; +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException; +import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.types.StructType; + +import java.util.Map; + +/** + * Catalog methods for working with Tables. + *
+ * TableCatalog implementations may be case sensitive or case insensitive. Spark will pass + * {@link Identifier table identifiers} without modification. Field names passed to + * {@link #alterTable(Identifier, TableChange...)} will be normalized to match the case used in the + * table schema when updating, renaming, or dropping existing columns when catalyst analysis is case + * insensitive. + */ +public interface TableCatalog extends CatalogPlugin { + /** + * List the tables in a namespace from the catalog. + *
+ * If the catalog supports views, this must return identifiers for only tables and not views. + * + * @param namespace a multi-part namespace + * @return an array of Identifiers for tables + * @throws NoSuchNamespaceException If the namespace does not exist (optional). + */ + Identifier[] listTables(String[] namespace) throws NoSuchNamespaceException; + + /** + * Load table metadata by {@link Identifier identifier} from the catalog. + *
+ * If the catalog supports views and contains a view for the identifier and not a table, this + * must throw {@link NoSuchTableException}. + * + * @param ident a table identifier + * @return the table's metadata + * @throws NoSuchTableException If the table doesn't exist or is a view + */ + Table loadTable(Identifier ident) throws NoSuchTableException; + + /** + * Invalidate cached table metadata for an {@link Identifier identifier}. + *
+ * If the table is already loaded or cached, drop cached data. If the table does not exist or is + * not cached, do nothing. Calling this method should not query remote services. + * + * @param ident a table identifier + */ + default void invalidateTable(Identifier ident) { + } + + /** + * Test whether a table exists using an {@link Identifier identifier} from the catalog. + *
+ * If the catalog supports views and contains a view for the identifier and not a table, this + * must return false. + * + * @param ident a table identifier + * @return true if the table exists, false otherwise + */ + default boolean tableExists(Identifier ident) { + try { + return loadTable(ident) != null; + } catch (NoSuchTableException e) { + return false; + } + } + + /** + * Create a table in the catalog. + * + * @param ident a table identifier + * @param schema the schema of the new table, as a struct type + * @param partitions transforms to use for partitioning data in the table + * @param properties a string map of table properties + * @return metadata for the new table + * @throws TableAlreadyExistsException If a table or view already exists for the identifier + * @throws UnsupportedOperationException If a requested partition transform is not supported + * @throws NoSuchNamespaceException If the identifier namespace does not exist (optional) + */ + Table createTable( + Identifier ident, + StructType schema, + Transform[] partitions, + Map properties) throws TableAlreadyExistsException, NoSuchNamespaceException; + + /** + * Apply a set of {@link TableChange changes} to a table in the catalog. + *
+ * Implementations may reject the requested changes. If any change is rejected, none of the + * changes should be applied to the table. + *
+ * If the catalog supports views and contains a view for the identifier and not a table, this + * must throw {@link NoSuchTableException}. + * + * @param ident a table identifier + * @param changes changes to apply to the table + * @return updated metadata for the table + * @throws NoSuchTableException If the table doesn't exist or is a view + * @throws IllegalArgumentException If any change is rejected by the implementation. + */ + Table alterTable( + Identifier ident, + TableChange... changes) throws NoSuchTableException; + + /** + * Drop a table in the catalog. + *
+ * If the catalog supports views and contains a view for the identifier and not a table, this + * must not drop the view and must return false. + * + * @param ident a table identifier + * @return true if a table was deleted, false if no table exists for the identifier + */ + boolean dropTable(Identifier ident); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java new file mode 100644 index 0000000000000..9b87e676d9b2d --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/TableChange.java @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2; + +import org.apache.spark.sql.types.DataType; + +/** + * TableChange subclasses represent requested changes to a table. These are passed to + * {@link TableCatalog#alterTable}. For example, + *
+ *   import TableChange._
+ *   val catalog = Catalogs.load(name)
+ *   catalog.asTableCatalog.alterTable(ident,
+ *       addColumn("x", IntegerType),
+ *       renameColumn("a", "b"),
+ *       deleteColumn("c")
+ *     )
+ * 
+ */ +public interface TableChange { + + /** + * Create a TableChange for setting a table property. + *
+ * If the property already exists, it will be replaced with the new value. + * + * @param property the property name + * @param value the new property value + * @return a TableChange for the addition + */ + static TableChange setProperty(String property, String value) { + return new SetProperty(property, value); + } + + /** + * Create a TableChange for removing a table property. + *
+ * If the property does not exist, the change will succeed. + * + * @param property the property name + * @return a TableChange for the addition + */ + static TableChange removeProperty(String property) { + return new RemoveProperty(property); + } + + /** + * Create a TableChange for adding an optional column. + *
+ * If the field already exists, the change will result in an {@link IllegalArgumentException}. + * If the new field is nested and its parent does not exist or is not a struct, the change will + * result in an {@link IllegalArgumentException}. + * + * @param fieldNames field names of the new column + * @param dataType the new column's data type + * @return a TableChange for the addition + */ + static TableChange addColumn(String[] fieldNames, DataType dataType) { + return new AddColumn(fieldNames, dataType, true, null); + } + + /** + * Create a TableChange for adding a column. + *
+ * If the field already exists, the change will result in an {@link IllegalArgumentException}. + * If the new field is nested and its parent does not exist or is not a struct, the change will + * result in an {@link IllegalArgumentException}. + * + * @param fieldNames field names of the new column + * @param dataType the new column's data type + * @param isNullable whether the new column can contain null + * @return a TableChange for the addition + */ + static TableChange addColumn(String[] fieldNames, DataType dataType, boolean isNullable) { + return new AddColumn(fieldNames, dataType, isNullable, null); + } + + /** + * Create a TableChange for adding a column. + *
+ * If the field already exists, the change will result in an {@link IllegalArgumentException}. + * If the new field is nested and its parent does not exist or is not a struct, the change will + * result in an {@link IllegalArgumentException}. + * + * @param fieldNames field names of the new column + * @param dataType the new column's data type + * @param isNullable whether the new column can contain null + * @param comment the new field's comment string + * @return a TableChange for the addition + */ + static TableChange addColumn( + String[] fieldNames, + DataType dataType, + boolean isNullable, + String comment) { + return new AddColumn(fieldNames, dataType, isNullable, comment); + } + + /** + * Create a TableChange for renaming a field. + *
+ * The name is used to find the field to rename. The new name will replace the leaf field name. + * For example, renameColumn(["a", "b", "c"], "x") should produce column a.b.x. + *
+ * If the field does not exist, the change will result in an {@link IllegalArgumentException}. + * + * @param fieldNames the current field names + * @param newName the new name + * @return a TableChange for the rename + */ + static TableChange renameColumn(String[] fieldNames, String newName) { + return new RenameColumn(fieldNames, newName); + } + + /** + * Create a TableChange for updating the type of a field that is nullable. + *
+ * The field names are used to find the field to update. + *
+ * If the field does not exist, the change will result in an {@link IllegalArgumentException}. + * + * @param fieldNames field names of the column to update + * @param newDataType the new data type + * @return a TableChange for the update + */ + static TableChange updateColumnType(String[] fieldNames, DataType newDataType) { + return new UpdateColumnType(fieldNames, newDataType, true); + } + + /** + * Create a TableChange for updating the type of a field. + *
+ * The field names are used to find the field to update. + *
+ * If the field does not exist, the change will result in an {@link IllegalArgumentException}. + * + * @param fieldNames field names of the column to update + * @param newDataType the new data type + * @return a TableChange for the update + */ + static TableChange updateColumnType( + String[] fieldNames, + DataType newDataType, + boolean isNullable) { + return new UpdateColumnType(fieldNames, newDataType, isNullable); + } + + /** + * Create a TableChange for updating the comment of a field. + *
+ * The name is used to find the field to update. + *
+ * If the field does not exist, the change will result in an {@link IllegalArgumentException}. + * + * @param fieldNames field names of the column to update + * @param newComment the new comment + * @return a TableChange for the update + */ + static TableChange updateColumnComment(String[] fieldNames, String newComment) { + return new UpdateColumnComment(fieldNames, newComment); + } + + /** + * Create a TableChange for deleting a field. + *
+ * If the field does not exist, the change will result in an {@link IllegalArgumentException}. + * + * @param fieldNames field names of the column to delete + * @return a TableChange for the delete + */ + static TableChange deleteColumn(String[] fieldNames) { + return new DeleteColumn(fieldNames); + } + + /** + * A TableChange to set a table property. + *
+ * If the property already exists, it must be replaced with the new value. + */ + final class SetProperty implements TableChange { + private final String property; + private final String value; + + private SetProperty(String property, String value) { + this.property = property; + this.value = value; + } + + public String property() { + return property; + } + + public String value() { + return value; + } + } + + /** + * A TableChange to remove a table property. + *
+ * If the property does not exist, the change should succeed. + */ + final class RemoveProperty implements TableChange { + private final String property; + + private RemoveProperty(String property) { + this.property = property; + } + + public String property() { + return property; + } + } + + /** + * A TableChange to add a field. + *
+ * If the field already exists, the change must result in an {@link IllegalArgumentException}. + * If the new field is nested and its parent does not exist or is not a struct, the change must + * result in an {@link IllegalArgumentException}. + */ + final class AddColumn implements TableChange { + private final String[] fieldNames; + private final DataType dataType; + private final boolean isNullable; + private final String comment; + + private AddColumn(String[] fieldNames, DataType dataType, boolean isNullable, String comment) { + this.fieldNames = fieldNames; + this.dataType = dataType; + this.isNullable = isNullable; + this.comment = comment; + } + + public String[] fieldNames() { + return fieldNames; + } + + public DataType dataType() { + return dataType; + } + + public boolean isNullable() { + return isNullable; + } + + public String comment() { + return comment; + } + } + + /** + * A TableChange to rename a field. + *
+ * The name is used to find the field to rename. The new name will replace the leaf field name. + * For example, renameColumn("a.b.c", "x") should produce column a.b.x. + *
+ * If the field does not exist, the change must result in an {@link IllegalArgumentException}. + */ + final class RenameColumn implements TableChange { + private final String[] fieldNames; + private final String newName; + + private RenameColumn(String[] fieldNames, String newName) { + this.fieldNames = fieldNames; + this.newName = newName; + } + + public String[] fieldNames() { + return fieldNames; + } + + public String newName() { + return newName; + } + } + + /** + * A TableChange to update the type of a field. + *
+ * The field names are used to find the field to update. + *
+ * If the field does not exist, the change must result in an {@link IllegalArgumentException}. + */ + final class UpdateColumnType implements TableChange { + private final String[] fieldNames; + private final DataType newDataType; + private final boolean isNullable; + + private UpdateColumnType(String[] fieldNames, DataType newDataType, boolean isNullable) { + this.fieldNames = fieldNames; + this.newDataType = newDataType; + this.isNullable = isNullable; + } + + public String[] fieldNames() { + return fieldNames; + } + + public DataType newDataType() { + return newDataType; + } + + public boolean isNullable() { + return isNullable; + } + } + + /** + * A TableChange to update the comment of a field. + *
+ * The field names are used to find the field to update. + *
+ * If the field does not exist, the change must result in an {@link IllegalArgumentException}. + */ + final class UpdateColumnComment implements TableChange { + private final String[] fieldNames; + private final String newComment; + + private UpdateColumnComment(String[] fieldNames, String newComment) { + this.fieldNames = fieldNames; + this.newComment = newComment; + } + + public String[] fieldNames() { + return fieldNames; + } + + public String newComment() { + return newComment; + } + } + + /** + * A TableChange to delete a field. + *
+ * If the field does not exist, the change must result in an {@link IllegalArgumentException}. + */ + final class DeleteColumn implements TableChange { + private final String[] fieldNames; + + private DeleteColumn(String[] fieldNames) { + this.fieldNames = fieldNames; + } + + public String[] fieldNames() { + return fieldNames; + } + } + +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java index 009e89bd4eb60..7b264e7480e17 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java @@ -40,7 +40,7 @@ private Expressions() { * @param args expression arguments to the transform * @return a logical transform */ - public Transform apply(String name, Expression... args) { + public static Transform apply(String name, Expression... args) { return LogicalExpressions.apply(name, JavaConverters.asScalaBuffer(Arrays.asList(args)).toSeq()); } @@ -51,7 +51,7 @@ public Transform apply(String name, Expression... args) { * @param name a column name * @return a named reference for the column */ - public NamedReference column(String name) { + public static NamedReference column(String name) { return LogicalExpressions.reference(name); } @@ -65,7 +65,7 @@ public NamedReference column(String name) { * @param the JVM type of the value * @return a literal expression for the value */ - public Literal literal(T value) { + public static Literal literal(T value) { return LogicalExpressions.literal(value); } @@ -81,7 +81,7 @@ public Literal literal(T value) { * @param columns input columns for the bucket transform * @return a logical bucket transform with name "bucket" */ - public Transform bucket(int numBuckets, String... columns) { + public static Transform bucket(int numBuckets, String... columns) { return LogicalExpressions.bucket(numBuckets, JavaConverters.asScalaBuffer(Arrays.asList(columns)).toSeq()); } @@ -96,7 +96,7 @@ public Transform bucket(int numBuckets, String... 
columns) { * @param column an input column * @return a logical identity transform with name "identity" */ - public Transform identity(String column) { + public static Transform identity(String column) { return LogicalExpressions.identity(column); } @@ -110,7 +110,7 @@ public Transform identity(String column) { * @param column an input timestamp or date column * @return a logical yearly transform with name "years" */ - public Transform years(String column) { + public static Transform years(String column) { return LogicalExpressions.years(column); } @@ -125,7 +125,7 @@ public Transform years(String column) { * @param column an input timestamp or date column * @return a logical monthly transform with name "months" */ - public Transform months(String column) { + public static Transform months(String column) { return LogicalExpressions.months(column); } @@ -140,7 +140,7 @@ public Transform months(String column) { * @param column an input timestamp or date column * @return a logical daily transform with name "days" */ - public Transform days(String column) { + public static Transform days(String column) { return LogicalExpressions.days(column); } @@ -155,7 +155,7 @@ public Transform days(String column) { * @param column an input timestamp column * @return a logical hourly transform with name "hours" */ - public Transform hours(String column) { + public static Transform hours(String column) { return LogicalExpressions.hours(column); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/Table.java similarity index 72% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/Table.java index 78f979a2a9a44..482d3c22e2306 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/Table.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/Table.java @@ -18,8 +18,11 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.catalog.v2.expressions.Transform; import org.apache.spark.sql.types.StructType; +import java.util.Collections; +import java.util.Map; import java.util.Set; /** @@ -29,6 +32,10 @@ *
* This interface can mixin the following interfaces to support different operations, like * {@code SupportsRead}. + *
+ * The default implementation of {@link #partitioning()} returns an empty array of partitions, and + * the default implementation of {@link #properties()} returns an empty map. These should be + * overridden by implementations that support partitioning and table properties. */ @Evolving public interface Table { @@ -45,6 +52,20 @@ public interface Table { */ StructType schema(); + /** + * Returns the physical partitioning of this table. + */ + default Transform[] partitioning() { + return new Transform[0]; + } + + /** + * Returns the string map of table properties. + */ + default Map properties() { + return Collections.emptyMap(); + } + /** * Returns the set of capabilities for this table. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java similarity index 92% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java index 8d3fdcd694e2c..5a9b85e6d0361 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java @@ -47,7 +47,7 @@ public enum TableCapability { *
* Truncating a table removes all existing rows. *
- * See {@link org.apache.spark.sql.sources.v2.writer.SupportsTruncate}. + * See {@code org.apache.spark.sql.sources.v2.writer.SupportsTruncate}. */ TRUNCATE, @@ -55,7 +55,7 @@ public enum TableCapability { * Signals that the table can replace existing data that matches a filter with appended data in * a write operation. *
- * See {@link org.apache.spark.sql.sources.v2.writer.SupportsOverwrite}. + * See {@code org.apache.spark.sql.sources.v2.writer.SupportsOverwrite}. */ OVERWRITE_BY_FILTER, @@ -63,7 +63,7 @@ public enum TableCapability { * Signals that the table can dynamically replace existing data partitions with appended data in * a write operation. *
- * See {@link org.apache.spark.sql.sources.v2.writer.SupportsDynamicOverwrite}. + * See {@code org.apache.spark.sql.sources.v2.writer.SupportsDynamicOverwrite}. */ OVERWRITE_DYNAMIC } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogV2Implicits.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogV2Implicits.scala new file mode 100644 index 0000000000000..f512cd5e23c6b --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/CatalogV2Implicits.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalog.v2 + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalog.v2.expressions.{BucketTransform, IdentityTransform, LogicalExpressions, Transform} +import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.types.StructType + +/** + * Conversion helpers for working with v2 [[CatalogPlugin]]. + */ +object CatalogV2Implicits { + implicit class PartitionTypeHelper(partitionType: StructType) { + def asTransforms: Array[Transform] = partitionType.names.map(LogicalExpressions.identity) + } + + implicit class BucketSpecHelper(spec: BucketSpec) { + def asTransform: BucketTransform = { + if (spec.sortColumnNames.nonEmpty) { + throw new AnalysisException( + s"Cannot convert bucketing with sort columns to a transform: $spec") + } + + LogicalExpressions.bucket(spec.numBuckets, spec.bucketColumnNames: _*) + } + } + + implicit class TransformHelper(transforms: Seq[Transform]) { + def asPartitionColumns: Seq[String] = { + val (idTransforms, nonIdTransforms) = transforms.partition(_.isInstanceOf[IdentityTransform]) + + if (nonIdTransforms.nonEmpty) { + throw new AnalysisException("Transforms cannot be converted to partition columns: " + + nonIdTransforms.map(_.describe).mkString(", ")) + } + + idTransforms.map(_.asInstanceOf[IdentityTransform]).map(_.reference).map { ref => + val parts = ref.fieldNames + if (parts.size > 1) { + throw new AnalysisException(s"Cannot partition by nested column: $ref") + } else { + parts(0) + } + } + } + } + + implicit class CatalogHelper(plugin: CatalogPlugin) { + def asTableCatalog: TableCatalog = plugin match { + case tableCatalog: TableCatalog => + tableCatalog + case _ => + throw new AnalysisException(s"Cannot use catalog ${plugin.name}: not a TableCatalog") + } + } + + implicit class NamespaceHelper(namespace: Array[String]) { + def quoted: String = namespace.map(quote).mkString(".") + } + + implicit class IdentifierHelper(ident: Identifier) { + def quoted: String = { + if (ident.namespace.nonEmpty) { + ident.namespace.map(quote).mkString(".") + "." 
+ quote(ident.name) + } else { + quote(ident.name) + } + } + } + + implicit class MultipartIdentifierHelper(namespace: Seq[String]) { + def quoted: String = namespace.map(quote).mkString(".") + } + + private def quote(part: String): String = { + if (part.contains(".") || part.contains("`")) { + s"`${part.replace("`", "``")}`" + } else { + part + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala index 813d88255c6a2..2d4d6e7c6d5ee 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/expressions/expressions.scala @@ -17,9 +17,7 @@ package org.apache.spark.sql.catalog.v2.expressions -import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst -import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{DataType, IntegerType, StringType} @@ -35,38 +33,6 @@ private[sql] object LogicalExpressions { // because this is only used for field names, the SQL conf passed in does not matter. private lazy val parser = new CatalystSqlParser(SQLConf.get) - def fromPartitionColumns(columns: String*): Array[IdentityTransform] = - columns.map(identity).toArray - - def fromBucketSpec(spec: BucketSpec): BucketTransform = { - if (spec.sortColumnNames.nonEmpty) { - throw new AnalysisException( - s"Cannot convert bucketing with sort columns to a transform: $spec") - } - - bucket(spec.numBuckets, spec.bucketColumnNames: _*) - } - - implicit class TransformHelper(transforms: Seq[Transform]) { - def asPartitionColumns: Seq[String] = { - val (idTransforms, nonIdTransforms) = transforms.partition(_.isInstanceOf[IdentityTransform]) - - if (nonIdTransforms.nonEmpty) { - throw new AnalysisException("Transforms cannot be converted to partition columns: " + - nonIdTransforms.map(_.describe).mkString(", ")) - } - - idTransforms.map(_.asInstanceOf[IdentityTransform]).map(_.reference).map { ref => - val parts = ref.fieldNames - if (parts.size > 1) { - throw new AnalysisException(s"Cannot partition by nested column: $ref") - } else { - parts(0) - } - } - } - } - def literal[T](value: T): LiteralValue[T] = { val internalLit = catalyst.expressions.Literal(value) literal(value, internalLit.dataType) @@ -183,17 +149,10 @@ private[sql] final case class LiteralValue[T](value: T, dataType: DataType) exte } private[sql] final case class FieldReference(parts: Seq[String]) extends NamedReference { + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.MultipartIdentifierHelper override def fieldNames: Array[String] = parts.toArray - override def describe: String = parts.map(quote).mkString(".") + override def describe: String = parts.quoted override def toString: String = describe - - private def quote(part: String): String = { - if (part.contains(".") || part.contains("`")) { - s"`${part.replace("`", "``")}`" - } else { - part - } - } } private[sql] object FieldReference { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala index 6d587abd8fd4d..f5e9a146bf359 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/AlreadyExistException.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ +import org.apache.spark.sql.catalog.v2.Identifier import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec /** @@ -25,13 +27,26 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. */ class DatabaseAlreadyExistsException(db: String) - extends AnalysisException(s"Database '$db' already exists") + extends NamespaceAlreadyExistsException(s"Database '$db' already exists") -class TableAlreadyExistsException(db: String, table: String) - extends AnalysisException(s"Table or view '$table' already exists in database '$db'") +class NamespaceAlreadyExistsException(message: String) extends AnalysisException(message) { + def this(namespace: Array[String]) = { + this(s"Namespace '${namespace.quoted}' already exists") + } +} + +class TableAlreadyExistsException(message: String) extends AnalysisException(message) { + def this(db: String, table: String) = { + this(s"Table or view '$table' already exists in database '$db'") + } + + def this(tableIdent: Identifier) = { + this(s"Table ${tableIdent.quoted} already exists") + } +} class TempTableAlreadyExistsException(table: String) - extends AnalysisException(s"Temporary view '$table' already exists") + extends TableAlreadyExistsException(s"Temporary view '$table' already exists") class PartitionAlreadyExistsException(db: String, table: String, spec: TablePartitionSpec) extends AnalysisException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala index 8bf6f69f3b17a..7ac8ae61ed537 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NoSuchItemException.scala @@ -18,6 +18,8 @@ package org.apache.spark.sql.catalyst.analysis import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ +import org.apache.spark.sql.catalog.v2.Identifier import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec @@ -25,10 +27,24 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTypes.TablePartitionSpec * Thrown by a catalog when an item cannot be found. The analyzer will rethrow the exception * as an [[org.apache.spark.sql.AnalysisException]] with the correct position information. 
*/ -class NoSuchDatabaseException(val db: String) extends AnalysisException(s"Database '$db' not found") +class NoSuchDatabaseException( + val db: String) extends NoSuchNamespaceException(s"Database '$db' not found") -class NoSuchTableException(db: String, table: String) - extends AnalysisException(s"Table or view '$table' not found in database '$db'") +class NoSuchNamespaceException(message: String) extends AnalysisException(message) { + def this(namespace: Array[String]) = { + this(s"Namespace '${namespace.quoted}' not found") + } +} + +class NoSuchTableException(message: String) extends AnalysisException(message) { + def this(db: String, table: String) = { + this(s"Table or view '$table' not found in database '$db'") + } + + def this(tableIdent: Identifier) = { + this(s"Table ${tableIdent.quoted} not found") + } +} class NoSuchPartitionException( db: String, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TableCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TableCatalogSuite.scala new file mode 100644 index 0000000000000..9c1b9a3e53de2 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TableCatalogSuite.scala @@ -0,0 +1,657 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalog.v2 + +import java.util +import java.util.Collections + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class TableCatalogSuite extends SparkFunSuite { + import CatalogV2Implicits._ + + private val emptyProps: util.Map[String, String] = Collections.emptyMap[String, String] + private val schema: StructType = new StructType() + .add("id", IntegerType) + .add("data", StringType) + + private def newCatalog(): TableCatalog = { + val newCatalog = new TestTableCatalog + newCatalog.initialize("test", CaseInsensitiveStringMap.empty()) + newCatalog + } + + private val testIdent = Identifier.of(Array("`", "."), "test_table") + + test("Catalogs can load the catalog") { + val catalog = newCatalog() + + val conf = new SQLConf + conf.setConfString("spark.sql.catalog.test", catalog.getClass.getName) + + val loaded = Catalogs.load("test", conf) + assert(loaded.getClass == catalog.getClass) + } + + test("listTables") { + val catalog = newCatalog() + val ident1 = Identifier.of(Array("ns"), "test_table_1") + val ident2 = Identifier.of(Array("ns"), "test_table_2") + val ident3 = Identifier.of(Array("ns2"), "test_table_1") + + assert(catalog.listTables(Array("ns")).isEmpty) + + catalog.createTable(ident1, schema, Array.empty, emptyProps) + + assert(catalog.listTables(Array("ns")).toSet == Set(ident1)) + assert(catalog.listTables(Array("ns2")).isEmpty) + + catalog.createTable(ident3, schema, Array.empty, emptyProps) + catalog.createTable(ident2, schema, Array.empty, emptyProps) + + assert(catalog.listTables(Array("ns")).toSet == Set(ident1, ident2)) + assert(catalog.listTables(Array("ns2")).toSet == Set(ident3)) + + catalog.dropTable(ident1) + + assert(catalog.listTables(Array("ns")).toSet == Set(ident2)) + + catalog.dropTable(ident2) + + assert(catalog.listTables(Array("ns")).isEmpty) + assert(catalog.listTables(Array("ns2")).toSet == Set(ident3)) + } + + test("createTable") { + val catalog = newCatalog() + + assert(!catalog.tableExists(testIdent)) + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name) + assert(parsed == Seq("`", ".", "test_table")) + assert(table.schema == schema) + assert(table.properties.asScala == Map()) + + assert(catalog.tableExists(testIdent)) + } + + test("createTable: with properties") { + val catalog = newCatalog() + + val properties = new util.HashMap[String, String]() + properties.put("property", "value") + + assert(!catalog.tableExists(testIdent)) + + val table = catalog.createTable(testIdent, schema, Array.empty, properties) + + val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name) + assert(parsed == Seq("`", ".", "test_table")) + assert(table.schema == schema) + assert(table.properties == properties) + + assert(catalog.tableExists(testIdent)) + } + + test("createTable: table already exists") { + val catalog = newCatalog() + + assert(!catalog.tableExists(testIdent)) + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + val exc = intercept[TableAlreadyExistsException] { + 
catalog.createTable(testIdent, schema, Array.empty, emptyProps) + } + + assert(exc.message.contains(table.name())) + assert(exc.message.contains("already exists")) + + assert(catalog.tableExists(testIdent)) + } + + test("tableExists") { + val catalog = newCatalog() + + assert(!catalog.tableExists(testIdent)) + + catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(catalog.tableExists(testIdent)) + + catalog.dropTable(testIdent) + + assert(!catalog.tableExists(testIdent)) + } + + test("loadTable") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + val loaded = catalog.loadTable(testIdent) + + assert(table.name == loaded.name) + assert(table.schema == loaded.schema) + assert(table.properties == loaded.properties) + } + + test("loadTable: table does not exist") { + val catalog = newCatalog() + + val exc = intercept[NoSuchTableException] { + catalog.loadTable(testIdent) + } + + assert(exc.message.contains(testIdent.quoted)) + assert(exc.message.contains("not found")) + } + + test("invalidateTable") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + catalog.invalidateTable(testIdent) + + val loaded = catalog.loadTable(testIdent) + + assert(table.name == loaded.name) + assert(table.schema == loaded.schema) + assert(table.properties == loaded.properties) + } + + test("invalidateTable: table does not exist") { + val catalog = newCatalog() + + assert(catalog.tableExists(testIdent) === false) + + catalog.invalidateTable(testIdent) + } + + test("alterTable: add property") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.properties.asScala == Map()) + + val updated = catalog.alterTable(testIdent, TableChange.setProperty("prop-1", "1")) + assert(updated.properties.asScala == Map("prop-1" -> "1")) + + val loaded = catalog.loadTable(testIdent) + assert(loaded.properties.asScala == Map("prop-1" -> "1")) + + assert(table.properties.asScala == Map()) + } + + test("alterTable: add property to existing") { + val catalog = newCatalog() + + val properties = new util.HashMap[String, String]() + properties.put("prop-1", "1") + + val table = catalog.createTable(testIdent, schema, Array.empty, properties) + + assert(table.properties.asScala == Map("prop-1" -> "1")) + + val updated = catalog.alterTable(testIdent, TableChange.setProperty("prop-2", "2")) + assert(updated.properties.asScala == Map("prop-1" -> "1", "prop-2" -> "2")) + + val loaded = catalog.loadTable(testIdent) + assert(loaded.properties.asScala == Map("prop-1" -> "1", "prop-2" -> "2")) + + assert(table.properties.asScala == Map("prop-1" -> "1")) + } + + test("alterTable: remove existing property") { + val catalog = newCatalog() + + val properties = new util.HashMap[String, String]() + properties.put("prop-1", "1") + + val table = catalog.createTable(testIdent, schema, Array.empty, properties) + + assert(table.properties.asScala == Map("prop-1" -> "1")) + + val updated = catalog.alterTable(testIdent, TableChange.removeProperty("prop-1")) + assert(updated.properties.asScala == Map()) + + val loaded = catalog.loadTable(testIdent) + assert(loaded.properties.asScala == Map()) + + assert(table.properties.asScala == Map("prop-1" -> "1")) + } + + test("alterTable: remove missing property") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.properties.asScala == 
Map()) + + val updated = catalog.alterTable(testIdent, TableChange.removeProperty("prop-1")) + assert(updated.properties.asScala == Map()) + + val loaded = catalog.loadTable(testIdent) + assert(loaded.properties.asScala == Map()) + + assert(table.properties.asScala == Map()) + } + + test("alterTable: add top-level column") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, TableChange.addColumn(Array("ts"), TimestampType)) + + assert(updated.schema == schema.add("ts", TimestampType)) + } + + test("alterTable: add required column") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, + TableChange.addColumn(Array("ts"), TimestampType, false)) + + assert(updated.schema == schema.add("ts", TimestampType, nullable = false)) + } + + test("alterTable: add column with comment") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, + TableChange.addColumn(Array("ts"), TimestampType, false, "comment text")) + + val field = StructField("ts", TimestampType, nullable = false).withComment("comment text") + assert(updated.schema == schema.add(field)) + } + + test("alterTable: add nested column") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.addColumn(Array("point", "z"), DoubleType)) + + val expectedSchema = schema.add("point", pointStruct.add("z", DoubleType)) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: add column to primitive field fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, TableChange.addColumn(Array("data", "ts"), TimestampType)) + } + + assert(exc.getMessage.contains("Not a struct")) + assert(exc.getMessage.contains("data")) + + // the table has not changed + assert(catalog.loadTable(testIdent).schema == schema) + } + + test("alterTable: add field to missing column fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, + TableChange.addColumn(Array("missing_col", "new_field"), StringType)) + } + + assert(exc.getMessage.contains("missing_col")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: update column data type") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, TableChange.updateColumnType(Array("id"), LongType)) + + val expectedSchema = new StructType().add("id", LongType).add("data", StringType) + assert(updated.schema == expectedSchema) + } + + test("alterTable: update column data type and nullability") { + val 
catalog = newCatalog() + + val originalSchema = new StructType() + .add("id", IntegerType, nullable = false) + .add("data", StringType) + val table = catalog.createTable(testIdent, originalSchema, Array.empty, emptyProps) + + assert(table.schema == originalSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.updateColumnType(Array("id"), LongType, true)) + + val expectedSchema = new StructType().add("id", LongType).add("data", StringType) + assert(updated.schema == expectedSchema) + } + + test("alterTable: update optional column to required fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, TableChange.updateColumnType(Array("id"), LongType, false)) + } + + assert(exc.getMessage.contains("Cannot change optional column to required")) + assert(exc.getMessage.contains("id")) + } + + test("alterTable: update missing column fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, + TableChange.updateColumnType(Array("missing_col"), LongType)) + } + + assert(exc.getMessage.contains("missing_col")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: add comment") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, + TableChange.updateColumnComment(Array("id"), "comment text")) + + val expectedSchema = new StructType() + .add("id", IntegerType, nullable = true, "comment text") + .add("data", StringType) + assert(updated.schema == expectedSchema) + } + + test("alterTable: replace comment") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + catalog.alterTable(testIdent, TableChange.updateColumnComment(Array("id"), "comment text")) + + val expectedSchema = new StructType() + .add("id", IntegerType, nullable = true, "replacement comment") + .add("data", StringType) + + val updated = catalog.alterTable(testIdent, + TableChange.updateColumnComment(Array("id"), "replacement comment")) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: add comment to missing column fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, + TableChange.updateColumnComment(Array("missing_col"), "comment")) + } + + assert(exc.getMessage.contains("missing_col")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: rename top-level column") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, TableChange.renameColumn(Array("id"), "some_id")) + + val expectedSchema = new StructType().add("some_id", IntegerType).add("data", StringType) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: rename nested column") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", 
DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.renameColumn(Array("point", "x"), "first")) + + val newPointStruct = new StructType().add("first", DoubleType).add("y", DoubleType) + val expectedSchema = schema.add("point", newPointStruct) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: rename struct column") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.renameColumn(Array("point"), "p")) + + val newPointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val expectedSchema = schema.add("p", newPointStruct) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: rename missing column fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, + TableChange.renameColumn(Array("missing_col"), "new_name")) + } + + assert(exc.getMessage.contains("missing_col")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: multiple changes") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.renameColumn(Array("point", "x"), "first"), + TableChange.renameColumn(Array("point", "y"), "second")) + + val newPointStruct = new StructType().add("first", DoubleType).add("second", DoubleType) + val expectedSchema = schema.add("point", newPointStruct) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: delete top-level column") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val updated = catalog.alterTable(testIdent, + TableChange.deleteColumn(Array("id"))) + + val expectedSchema = new StructType().add("data", StringType) + assert(updated.schema == expectedSchema) + } + + test("alterTable: delete nested column") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val updated = catalog.alterTable(testIdent, + TableChange.deleteColumn(Array("point", "y"))) + + val newPointStruct = new StructType().add("x", DoubleType) + val expectedSchema = schema.add("point", newPointStruct) + + assert(updated.schema == expectedSchema) + } + + test("alterTable: delete missing column fails") { + val catalog = newCatalog() + + val table = catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(table.schema == schema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, 
TableChange.deleteColumn(Array("missing_col"))) + } + + assert(exc.getMessage.contains("missing_col")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: delete missing nested column fails") { + val catalog = newCatalog() + + val pointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) + val tableSchema = schema.add("point", pointStruct) + + val table = catalog.createTable(testIdent, tableSchema, Array.empty, emptyProps) + + assert(table.schema == tableSchema) + + val exc = intercept[IllegalArgumentException] { + catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z"))) + } + + assert(exc.getMessage.contains("z")) + assert(exc.getMessage.contains("Cannot find")) + } + + test("alterTable: table does not exist") { + val catalog = newCatalog() + + val exc = intercept[NoSuchTableException] { + catalog.alterTable(testIdent, TableChange.setProperty("prop", "val")) + } + + assert(exc.message.contains(testIdent.quoted)) + assert(exc.message.contains("not found")) + } + + test("dropTable") { + val catalog = newCatalog() + + assert(!catalog.tableExists(testIdent)) + + catalog.createTable(testIdent, schema, Array.empty, emptyProps) + + assert(catalog.tableExists(testIdent)) + + val wasDropped = catalog.dropTable(testIdent) + + assert(wasDropped) + assert(!catalog.tableExists(testIdent)) + } + + test("dropTable: table does not exist") { + val catalog = newCatalog() + + assert(!catalog.tableExists(testIdent)) + + val wasDropped = catalog.dropTable(testIdent) + + assert(!wasDropped) + assert(!catalog.tableExists(testIdent)) + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TestTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TestTableCatalog.scala new file mode 100644 index 0000000000000..7a0b014a85462 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TestTableCatalog.scala @@ -0,0 +1,220 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalog.v2 + +import java.util +import java.util.Collections +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.catalog.v2.TableChange.{AddColumn, DeleteColumn, RemoveProperty, RenameColumn, SetProperty, UpdateColumnComment, UpdateColumnType} +import org.apache.spark.sql.catalog.v2.expressions.Transform +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.sources.v2.{Table, TableCapability} +import org.apache.spark.sql.types.{StructField, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class TestTableCatalog extends TableCatalog { + import CatalogV2Implicits._ + + private val tables: util.Map[Identifier, Table] = new ConcurrentHashMap[Identifier, Table]() + private var _name: Option[String] = None + + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = { + _name = Some(name) + } + + override def name: String = _name.get + + override def listTables(namespace: Array[String]): Array[Identifier] = { + tables.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray + } + + override def loadTable(ident: Identifier): Table = { + Option(tables.get(ident)) match { + case Some(table) => + table + case _ => + throw new NoSuchTableException(ident) + } + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + + if (tables.containsKey(ident)) { + throw new TableAlreadyExistsException(ident) + } + + if (partitions.nonEmpty) { + throw new UnsupportedOperationException( + s"Catalog $name: Partitioned tables are not supported") + } + + val table = InMemoryTable(ident.quoted, schema, properties) + + tables.put(ident, table) + + table + } + + override def alterTable(ident: Identifier, changes: TableChange*): Table = { + val table = loadTable(ident) + val properties = TestTableCatalog.applyPropertiesChanges(table.properties, changes) + val schema = TestTableCatalog.applySchemaChanges(table.schema, changes) + val newTable = InMemoryTable(table.name, schema, properties) + + tables.put(ident, newTable) + + newTable + } + + override def dropTable(ident: Identifier): Boolean = Option(tables.remove(ident)).isDefined +} + +private object TestTableCatalog { + /** + * Apply properties changes to a map and return the result. + */ + def applyPropertiesChanges( + properties: util.Map[String, String], + changes: Seq[TableChange]): util.Map[String, String] = { + val newProperties = new util.HashMap[String, String](properties) + + changes.foreach { + case set: SetProperty => + newProperties.put(set.property, set.value) + + case unset: RemoveProperty => + newProperties.remove(unset.property) + + case _ => + // ignore non-property changes + } + + Collections.unmodifiableMap(newProperties) + } + + /** + * Apply schema changes to a schema and return the result. 
+ */ + def applySchemaChanges(schema: StructType, changes: Seq[TableChange]): StructType = { + changes.foldLeft(schema) { (schema, change) => + change match { + case add: AddColumn => + add.fieldNames match { + case Array(name) => + val newField = StructField(name, add.dataType, nullable = add.isNullable) + Option(add.comment) match { + case Some(comment) => + schema.add(newField.withComment(comment)) + case _ => + schema.add(newField) + } + + case names => + replace(schema, names.init, parent => parent.dataType match { + case parentType: StructType => + val field = StructField(names.last, add.dataType, nullable = add.isNullable) + val newParentType = Option(add.comment) match { + case Some(comment) => + parentType.add(field.withComment(comment)) + case None => + parentType.add(field) + } + + Some(StructField(parent.name, newParentType, parent.nullable, parent.metadata)) + + case _ => + throw new IllegalArgumentException(s"Not a struct: ${names.init.last}") + }) + } + + case rename: RenameColumn => + replace(schema, rename.fieldNames, field => + Some(StructField(rename.newName, field.dataType, field.nullable, field.metadata))) + + case update: UpdateColumnType => + replace(schema, update.fieldNames, field => { + if (!update.isNullable && field.nullable) { + throw new IllegalArgumentException( + s"Cannot change optional column to required: $field.name") + } + Some(StructField(field.name, update.newDataType, update.isNullable, field.metadata)) + }) + + case update: UpdateColumnComment => + replace(schema, update.fieldNames, field => + Some(field.withComment(update.newComment))) + + case delete: DeleteColumn => + replace(schema, delete.fieldNames, _ => None) + + case _ => + // ignore non-schema changes + schema + } + } + } + + private def replace( + struct: StructType, + path: Seq[String], + update: StructField => Option[StructField]): StructType = { + + val pos = struct.getFieldIndex(path.head) + .getOrElse(throw new IllegalArgumentException(s"Cannot find field: ${path.head}")) + val field = struct.fields(pos) + val replacement: Option[StructField] = if (path.tail.isEmpty) { + update(field) + } else { + field.dataType match { + case nestedStruct: StructType => + val updatedType: StructType = replace(nestedStruct, path.tail, update) + Some(StructField(field.name, updatedType, field.nullable, field.metadata)) + case _ => + throw new IllegalArgumentException(s"Not a struct: ${path.head}") + } + } + + val newFields = struct.fields.zipWithIndex.flatMap { + case (_, index) if pos == index => + replacement + case (other, _) => + Some(other) + } + + new StructType(newFields) + } +} + +case class InMemoryTable( + name: String, + schema: StructType, + override val properties: util.Map[String, String]) extends Table { + override def partitioning: Array[Transform] = Array.empty + override def capabilities: util.Set[TableCapability] = InMemoryTable.CAPABILITIES +} + +object InMemoryTable { + val CAPABILITIES: util.Set[TableCapability] = Set.empty[TableCapability].asJava +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index f503ff03b971c..30ecad642dc16 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.sources.v2.TableProvider import 
org.apache.spark.sql.types.StructType case class DataSourceResolution(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - import org.apache.spark.sql.catalog.v2.expressions.LogicalExpressions.TransformHelper + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.TransformHelper override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case CreateTableStatement( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index c00e65b07312f..deada9a83964b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -16,11 +16,14 @@ */ package org.apache.spark.sql.execution.datasources.v2 +import java.util + import scala.collection.JavaConverters._ import org.apache.hadoop.fs.FileStatus import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, Table, TableCapability} import org.apache.spark.sql.sources.v2.TableCapability._ @@ -34,6 +37,8 @@ abstract class FileTable( userSpecifiedSchema: Option[StructType]) extends Table with SupportsRead with SupportsWrite { + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + lazy val fileIndex: PartitioningAwareFileIndex = { val scalaMap = options.asScala.toMap val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(scalaMap) @@ -59,7 +64,11 @@ abstract class FileTable( fileIndex.partitionSchema, caseSensitive)._1 } - override def capabilities(): java.util.Set[TableCapability] = FileTable.CAPABILITIES + override def partitioning: Array[Transform] = fileIndex.partitionSchema.asTransforms + + override def properties: util.Map[String, String] = options.asCaseSensitiveMap + + override def capabilities: java.util.Set[TableCapability] = FileTable.CAPABILITIES /** * When possible, this method should return the schema of the given `files`. 
When the format diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java index dfbea927e477b..391af5a306a16 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java @@ -20,6 +20,8 @@ import java.io.IOException; import java.util.Arrays; +import org.apache.spark.sql.catalog.v2.expressions.Expressions; +import org.apache.spark.sql.catalog.v2.expressions.Transform; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.sources.v2.Table; @@ -56,6 +58,11 @@ public Partitioning outputPartitioning() { @Override public Table getTable(CaseInsensitiveStringMap options) { return new JavaSimpleBatchTable() { + @Override + public Transform[] partitioning() { + return new Transform[] { Expressions.identity("i") }; + } + @Override public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) { return new MyScanBuilder(); From d2b526cb5bacfad3d1602474ba945f38b363d485 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 15 May 2019 11:24:03 +0800 Subject: [PATCH 29/70] [SPARK-24923][SQL] Implement v2 CreateTableAsSelect This adds a v2 implementation for CTAS queries * Update the SQL parser to parse CREATE queries using multi-part identifiers * Update `CheckAnalysis` to validate partitioning references with the CTAS query schema * Add `CreateTableAsSelect` v2 logical plan and `CreateTableAsSelectExec` v2 physical plan * Update create conversion from `CreateTableAsSelectStatement` to support the new v2 logical plan * Update `DataSourceV2Strategy` to convert v2 CTAS logical plan to the new physical plan * Add `findNestedField` to `StructType` to support reference validation We have been running these changes in production for several months. Also: * Add a test suite `CreateTablePartitioningValidationSuite` for new analysis checks * Add a test suite for v2 SQL, `DataSourceV2SQLSuite` * Update catalyst `DDLParserSuite` to use multi-part identifiers (`Seq[String]`) * Add test cases to `PlanResolutionSuite` for v2 CTAS: known catalog and v2 source implementation Closes #24570 from rdblue/SPARK-24923-add-v2-ctas. 
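
As a rough sketch of the kind of query this enables (the catalog name `testcat`, the `TestInMemoryTableCatalog` registration, and the existing `source` table below are illustrative assumptions, not part of the patch itself):

    // Hypothetical setup: register a v2 catalog under the name `testcat`.
    spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName)

    // v2 CTAS: the multi-part identifier `testcat.ns.t` routes the command to the v2 catalog,
    // CheckAnalysis verifies that the PARTITIONED BY column `id` exists in the query schema,
    // and CreateTableAsSelectExec creates the table and writes the query result into it.
    spark.sql("CREATE TABLE testcat.ns.t USING foo PARTITIONED BY (id) AS SELECT id, data FROM source")
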
Authored-by: Ryan Blue Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/parser/SqlBase.g4 | 2 +- .../sql/catalyst/analysis/CheckAnalysis.scala | 17 ++ .../sql/catalyst/parser/AstBuilder.scala | 5 +- .../plans/logical/basicLogicalOperators.scala | 24 ++ .../logical/sql/CreateTableStatement.scala | 5 +- .../apache/spark/sql/types/StructType.scala | 23 ++ .../sql/catalog/v2/TestTableCatalog.scala | 2 +- ...eateTablePartitioningValidationSuite.scala | 153 ++++++++++++ .../sql/catalyst/parser/DDLParserSuite.scala | 23 +- .../spark/sql/execution/SparkSqlParser.scala | 28 ++- .../datasources/DataSourceResolution.scala | 82 ++++++- .../datasources/v2/DataSourceV2Strategy.scala | 9 +- .../v2/WriteToDataSourceV2Exec.scala | 58 +++++ .../internal/BaseSessionStateBuilder.scala | 2 +- .../command/PlanResolutionSuite.scala | 92 ++++++- .../sql/sources/v2/DataSourceV2SQLSuite.scala | 166 +++++++++++++ .../sources/v2/TestInMemoryTableCatalog.scala | 229 ++++++++++++++++++ .../sql/hive/HiveSessionStateBuilder.scala | 2 +- 18 files changed, 888 insertions(+), 34 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index d9caea170bc47..cce108b0e1f03 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -237,7 +237,7 @@ unsupportedHiveNativeCommands ; createTableHeader - : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? tableIdentifier + : CREATE TEMPORARY? EXTERNAL? TABLE (IF NOT EXISTS)? multipartIdentifier ; bucketSpec diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala index 18c40b370cb5f..fcb2eec609c28 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala @@ -33,6 +33,8 @@ import org.apache.spark.sql.types._ */ trait CheckAnalysis extends PredicateHelper { + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + /** * Override to provide additional checks for correct analysis. * These rules will be evaluated after our built-in check rules. 
@@ -296,6 +298,21 @@ trait CheckAnalysis extends PredicateHelper { } } + case CreateTableAsSelect(_, _, partitioning, query, _, _, _) => + val references = partitioning.flatMap(_.references).toSet + val badReferences = references.map(_.fieldNames).flatMap { column => + query.schema.findNestedField(column) match { + case Some(_) => + None + case _ => + Some(s"${column.quoted} is missing or is in a map or array") + } + } + + if (badReferences.nonEmpty) { + failAnalysis(s"Invalid partitioning: ${badReferences.mkString(", ")}") + } + case _ => // Fallbacks to the following checks } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index b18959f2e972a..270c99d6cca8c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -1975,7 +1975,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging /** * Type to keep track of a table header: (identifier, isTemporary, ifNotExists, isExternal). */ - type TableHeader = (TableIdentifier, Boolean, Boolean, Boolean) + type TableHeader = (Seq[String], Boolean, Boolean, Boolean) /** * Validate a create table statement and return the [[TableIdentifier]]. @@ -1987,7 +1987,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging if (temporary && ifNotExists) { operationNotAllowed("CREATE TEMPORARY TABLE ... IF NOT EXISTS", ctx) } - (visitTableIdentifier(ctx.tableIdentifier), temporary, ifNotExists, ctx.EXTERNAL != null) + val multipartIdentifier = ctx.multipartIdentifier.parts.asScala.map(_.getText) + (multipartIdentifier, temporary, ifNotExists, ctx.EXTERNAL != null) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index f7f701cea51fb..2bbe0bb006897 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog} +import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.AliasIdentifier import org.apache.spark.sql.catalyst.analysis.{MultiInstanceRelation, NamedRelation} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable} @@ -387,6 +389,28 @@ trait V2WriteCommand extends Command { } } +/** + * Create a new table from a select query with a v2 catalog. 
+ */ +case class CreateTableAsSelect( + catalog: TableCatalog, + tableName: Identifier, + partitioning: Seq[Transform], + query: LogicalPlan, + properties: Map[String, String], + writeOptions: Map[String, String], + ignoreIfExists: Boolean) extends Command { + + override def children: Seq[LogicalPlan] = Seq(query) + + override lazy val resolved: Boolean = { + // the table schema is created from the query schema, so the only resolution needed is to check + // that the columns referenced by the table's partitioning exist in the query schema + val references = partitioning.flatMap(_.references).toSet + references.map(_.fieldNames).forall(query.schema.findNestedField(_).isDefined) + } +} + /** * Append data to an existing table. */ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala index ed1b3e3778c7f..7a26e01cde830 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.plans.logical.sql import org.apache.spark.sql.catalog.v2.expressions.Transform -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -30,7 +29,7 @@ import org.apache.spark.sql.types.StructType * This is a metadata-only command and is not used to write data to the created table. */ case class CreateTableStatement( - table: TableIdentifier, + tableName: Seq[String], tableSchema: StructType, partitioning: Seq[Transform], bucketSpec: Option[BucketSpec], @@ -50,7 +49,7 @@ case class CreateTableStatement( * A CREATE TABLE AS SELECT command, as parsed from SQL. */ case class CreateTableAsSelectStatement( - table: TableIdentifier, + tableName: Seq[String], asSelect: LogicalPlan, partitioning: Seq[Transform], bucketSpec: Option[BucketSpec], diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala index d563276a5711d..c472bd8ee84b9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/StructType.scala @@ -307,6 +307,29 @@ case class StructType(fields: Array[StructField]) extends DataType with Seq[Stru nameToIndex.get(name) } + /** + * Returns a field in this struct and its child structs. + * + * This does not support finding fields nested in maps or arrays. 
+ */ + private[sql] def findNestedField(fieldNames: Seq[String]): Option[StructField] = { + fieldNames.headOption.flatMap(nameToField.get) match { + case Some(field) => + if (fieldNames.tail.isEmpty) { + Some(field) + } else { + field.dataType match { + case struct: StructType => + struct.findNestedField(fieldNames.tail) + case _ => + None + } + } + case _ => + None + } + } + protected[sql] def toAttributes: Seq[AttributeReference] = map(f => AttributeReference(f.name, f.dataType, f.nullable, f.metadata)()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TestTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TestTableCatalog.scala index 7a0b014a85462..78b4763484cc0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TestTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalog/v2/TestTableCatalog.scala @@ -91,7 +91,7 @@ class TestTableCatalog extends TableCatalog { override def dropTable(ident: Identifier): Boolean = Option(tables.remove(ident)).isDefined } -private object TestTableCatalog { +object TestTableCatalog { /** * Apply properties changes to a map and return the result. */ diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala new file mode 100644 index 0000000000000..1ce8852f71bc8 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/CreateTablePartitioningValidationSuite.scala @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.catalyst.analysis + +import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog, TestTableCatalog} +import org.apache.spark.sql.catalog.v2.expressions.LogicalExpressions +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LeafNode} +import org.apache.spark.sql.types.{DoubleType, LongType, StringType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class CreateTablePartitioningValidationSuite extends AnalysisTest { + import CreateTablePartitioningValidationSuite._ + + test("CreateTableAsSelect: fail missing top-level column") { + val plan = CreateTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + LogicalExpressions.bucket(4, "does_not_exist") :: Nil, + TestRelation2, + Map.empty, + Map.empty, + ignoreIfExists = false) + + assert(!plan.resolved) + assertAnalysisError(plan, Seq( + "Invalid partitioning", + "does_not_exist is missing or is in a map or array")) + } + + test("CreateTableAsSelect: fail missing top-level column nested reference") { + val plan = CreateTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + LogicalExpressions.bucket(4, "does_not_exist.z") :: Nil, + TestRelation2, + Map.empty, + Map.empty, + ignoreIfExists = false) + + assert(!plan.resolved) + assertAnalysisError(plan, Seq( + "Invalid partitioning", + "does_not_exist.z is missing or is in a map or array")) + } + + test("CreateTableAsSelect: fail missing nested column") { + val plan = CreateTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + LogicalExpressions.bucket(4, "point.z") :: Nil, + TestRelation2, + Map.empty, + Map.empty, + ignoreIfExists = false) + + assert(!plan.resolved) + assertAnalysisError(plan, Seq( + "Invalid partitioning", + "point.z is missing or is in a map or array")) + } + + test("CreateTableAsSelect: fail with multiple errors") { + val plan = CreateTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + LogicalExpressions.bucket(4, "does_not_exist", "point.z") :: Nil, + TestRelation2, + Map.empty, + Map.empty, + ignoreIfExists = false) + + assert(!plan.resolved) + assertAnalysisError(plan, Seq( + "Invalid partitioning", + "point.z is missing or is in a map or array", + "does_not_exist is missing or is in a map or array")) + } + + test("CreateTableAsSelect: success with top-level column") { + val plan = CreateTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + LogicalExpressions.bucket(4, "id") :: Nil, + TestRelation2, + Map.empty, + Map.empty, + ignoreIfExists = false) + + assertAnalysisSuccess(plan) + } + + test("CreateTableAsSelect: success using nested column") { + val plan = CreateTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + LogicalExpressions.bucket(4, "point.x") :: Nil, + TestRelation2, + Map.empty, + Map.empty, + ignoreIfExists = false) + + assertAnalysisSuccess(plan) + } + + test("CreateTableAsSelect: success using complex column") { + val plan = CreateTableAsSelect( + catalog, + Identifier.of(Array(), "table_name"), + LogicalExpressions.bucket(4, "point") :: Nil, + TestRelation2, + Map.empty, + Map.empty, + ignoreIfExists = false) + + assertAnalysisSuccess(plan) + } +} + +private object CreateTablePartitioningValidationSuite { + val catalog: TableCatalog = { + val cat = new TestTableCatalog() + cat.initialize("test", CaseInsensitiveStringMap.empty()) + cat + } + + val schema: StructType = new StructType() + .add("id", LongType) + 
.add("data", StringType) + .add("point", new StructType().add("x", DoubleType).add("y", DoubleType)) +} + +private case object TestRelation2 extends LeafNode with NamedRelation { + override def name: String = "source_relation" + override def output: Seq[AttributeReference] = + CreateTablePartitioningValidationSuite.schema.toAttributes +} + diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 98388a74cd29d..08baebbf140e6 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -18,7 +18,6 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.catalog.v2.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement} @@ -40,7 +39,7 @@ class DDLParserSuite extends AnalysisTest { parsePlan(sql) match { case create: CreateTableStatement => - assert(create.table == TableIdentifier("my_tab")) + assert(create.tableName == Seq("my_tab")) assert(create.tableSchema == new StructType() .add("a", IntegerType, nullable = true, "test") .add("b", StringType)) @@ -67,7 +66,7 @@ class DDLParserSuite extends AnalysisTest { parsePlan(sql) match { case create: CreateTableStatement => - assert(create.table == TableIdentifier("my_tab")) + assert(create.tableName == Seq("my_tab")) assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType)) assert(create.partitioning.isEmpty) assert(create.bucketSpec.isEmpty) @@ -90,7 +89,7 @@ class DDLParserSuite extends AnalysisTest { parsePlan(query) match { case create: CreateTableStatement => - assert(create.table == TableIdentifier("my_tab")) + assert(create.tableName == Seq("my_tab")) assert(create.tableSchema == new StructType() .add("a", IntegerType, nullable = true, "test") .add("b", StringType)) @@ -125,7 +124,7 @@ class DDLParserSuite extends AnalysisTest { parsePlan(sql) match { case create: CreateTableStatement => - assert(create.table == TableIdentifier("my_tab")) + assert(create.tableName == Seq("my_tab")) assert(create.tableSchema == new StructType() .add("a", IntegerType) .add("b", StringType) @@ -161,7 +160,7 @@ class DDLParserSuite extends AnalysisTest { parsePlan(query) match { case create: CreateTableStatement => - assert(create.table == TableIdentifier("my_tab")) + assert(create.tableName == Seq("my_tab")) assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType)) assert(create.partitioning.isEmpty) assert(create.bucketSpec.contains(BucketSpec(5, Seq("a"), Seq("b")))) @@ -183,7 +182,7 @@ class DDLParserSuite extends AnalysisTest { parsePlan(sql) match { case create: CreateTableStatement => - assert(create.table == TableIdentifier("my_tab")) + assert(create.tableName == Seq("my_tab")) assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType)) assert(create.partitioning.isEmpty) assert(create.bucketSpec.isEmpty) @@ -205,7 +204,7 @@ class DDLParserSuite extends AnalysisTest { parsePlan(sql) match { case create: 
CreateTableStatement => - assert(create.table == TableIdentifier("my_tab")) + assert(create.tableName == Seq("my_tab")) assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType)) assert(create.partitioning.isEmpty) assert(create.bucketSpec.isEmpty) @@ -227,7 +226,7 @@ class DDLParserSuite extends AnalysisTest { parsePlan(sql) match { case create: CreateTableStatement => - assert(create.table == TableIdentifier("my_tab")) + assert(create.tableName == Seq("my_tab")) assert(create.tableSchema == new StructType().add("a", IntegerType).add("b", StringType)) assert(create.partitioning.isEmpty) assert(create.bucketSpec.isEmpty) @@ -249,7 +248,7 @@ class DDLParserSuite extends AnalysisTest { parsePlan(sql) match { case create: CreateTableStatement => - assert(create.table == TableIdentifier("2g", Some("1m"))) + assert(create.tableName == Seq("1m", "2g")) assert(create.tableSchema == new StructType().add("a", IntegerType)) assert(create.partitioning.isEmpty) assert(create.bucketSpec.isEmpty) @@ -292,7 +291,7 @@ class DDLParserSuite extends AnalysisTest { parsePlan(sql) match { case create: CreateTableStatement => - assert(create.table == TableIdentifier("table_name")) + assert(create.tableName == Seq("table_name")) assert(create.tableSchema == new StructType) assert(create.partitioning.isEmpty) assert(create.bucketSpec.isEmpty) @@ -347,7 +346,7 @@ class DDLParserSuite extends AnalysisTest { def checkParsing(sql: String): Unit = { parsePlan(sql) match { case create: CreateTableAsSelectStatement => - assert(create.table == TableIdentifier("page_view", Some("mydb"))) + assert(create.tableName == Seq("mydb", "page_view")) assert(create.partitioning.isEmpty) assert(create.bucketSpec.isEmpty) assert(create.properties == Map("p1" -> "v1", "p2" -> "v2")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index b997399007cd5..dd273937f9788 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -25,7 +25,7 @@ import org.antlr.v4.runtime.{ParserRuleContext, Token} import org.antlr.v4.runtime.tree.TerminalNode import org.apache.spark.sql.SaveMode -import org.apache.spark.sql.catalyst.FunctionIdentifier +import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier} import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.parser._ @@ -376,6 +376,25 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { DescribeQueryCommand(visitQueryToDesc(ctx.queryToDesc())) } + /** + * Converts a multi-part identifier to a TableIdentifier. + * + * If the multi-part identifier has too many parts, this will throw a ParseException. + */ + def tableIdentifier( + multipart: Seq[String], + command: String, + ctx: ParserRuleContext): TableIdentifier = { + multipart match { + case Seq(tableName) => + TableIdentifier(tableName) + case Seq(database, tableName) => + TableIdentifier(tableName, Some(database)) + case _ => + operationNotAllowed(s"$command does not support multi-part identifiers", ctx) + } + } + /** * Create a table, returning a [[CreateTable]] logical plan. * @@ -386,7 +405,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * it is deprecated. 
*/ override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { - val (table, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + val (ident, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) if (!temp || ctx.query != null) { super.visitCreateTable(ctx) @@ -415,6 +434,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { logWarning(s"CREATE TEMPORARY TABLE ... USING ... is deprecated, please use " + "CREATE TEMPORARY VIEW ... USING ... instead") + val table = tableIdentifier(ident, "CREATE TEMPORARY VIEW", ctx) CreateTempViewUsing(table, schema, replace = false, global = false, provider, options) } } @@ -948,7 +968,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * }}} */ override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = withOrigin(ctx) { - val (name, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) + val (ident, temp, ifNotExists, external) = visitCreateTableHeader(ctx.createTableHeader) // TODO: implement temporary tables if (temp) { throw new ParseException( @@ -1006,6 +1026,8 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { CatalogTableType.MANAGED } + val name = tableIdentifier(ident, "CREATE TABLE ... STORED AS ...", ctx) + // TODO support the sql text - have a proper location for this! val tableDesc = CatalogTable( identifier = name, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index 30ecad642dc16..6d1cbe18c900c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -19,25 +19,34 @@ package org.apache.spark.sql.execution.datasources import java.util.Locale +import scala.collection.mutable + import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalog.v2.{CatalogPlugin, Identifier, LookupCatalog, TableCatalog} import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LogicalPlan} import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.TableProvider import org.apache.spark.sql.types.StructType -case class DataSourceResolution(conf: SQLConf) extends Rule[LogicalPlan] with CastSupport { - import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.TransformHelper +case class DataSourceResolution( + conf: SQLConf, + findCatalog: String => CatalogPlugin) + extends Rule[LogicalPlan] with CastSupport with LookupCatalog { + + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + + override def lookupCatalog: Option[String => CatalogPlugin] = Some(findCatalog) override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case CreateTableStatement( - table, schema, partitionCols, bucketSpec, properties, 
V1WriteProvider(provider), options, - location, comment, ifNotExists) => + AsTableIdentifier(table), schema, partitionCols, bucketSpec, properties, + V1WriteProvider(provider), options, location, comment, ifNotExists) => val tableDesc = buildCatalogTable(table, schema, partitionCols, bucketSpec, properties, provider, options, location, comment, ifNotExists) @@ -46,14 +55,23 @@ case class DataSourceResolution(conf: SQLConf) extends Rule[LogicalPlan] with Ca CreateTable(tableDesc, mode, None) case CreateTableAsSelectStatement( - table, query, partitionCols, bucketSpec, properties, V1WriteProvider(provider), options, - location, comment, ifNotExists) => + AsTableIdentifier(table), query, partitionCols, bucketSpec, properties, + V1WriteProvider(provider), options, location, comment, ifNotExists) => val tableDesc = buildCatalogTable(table, new StructType, partitionCols, bucketSpec, properties, provider, options, location, comment, ifNotExists) val mode = if (ifNotExists) SaveMode.Ignore else SaveMode.ErrorIfExists CreateTable(tableDesc, mode, Some(query)) + + case create: CreateTableAsSelectStatement => + // the provider was not a v1 source, convert to a v2 plan + val CatalogObjectIdentifier(maybeCatalog, identifier) = create.tableName + val catalog = maybeCatalog + .getOrElse(throw new AnalysisException( + s"No catalog specified for table ${identifier.quoted} and no default catalog is set")) + .asTableCatalog + convertCTAS(catalog, identifier, create) } object V1WriteProvider { @@ -112,4 +130,54 @@ case class DataSourceResolution(conf: SQLConf) extends Rule[LogicalPlan] with Ca properties = properties, comment = comment) } + + private def convertCTAS( + catalog: TableCatalog, + identifier: Identifier, + ctas: CreateTableAsSelectStatement): CreateTableAsSelect = { + if (ctas.options.contains("path") && ctas.location.isDefined) { + throw new AnalysisException( + "LOCATION and 'path' in OPTIONS are both used to indicate the custom table path, " + + "you can only specify one of them.") + } + + if ((ctas.options.contains("provider") || ctas.properties.contains("provider")) + && ctas.comment.isDefined) { + throw new AnalysisException( + "COMMENT and option/property 'comment' are both used to set the table comment, you can " + + "only specify one of them.") + } + + if (ctas.options.contains("provider") || ctas.properties.contains("provider")) { + throw new AnalysisException( + "USING and option/property 'provider' are both used to set the provider implementation, " + + "you can only specify one of them.") + } + + val options = ctas.options.filterKeys(_ != "path") + + // convert the bucket spec and add it as a transform + val partitioning = ctas.partitioning ++ ctas.bucketSpec.map(_.asTransform) + + // create table properties from TBLPROPERTIES and OPTIONS clauses + val properties = new mutable.HashMap[String, String]() + properties ++= ctas.properties + properties ++= options + + // convert USING, LOCATION, and COMMENT clauses to table properties + properties += ("provider" -> ctas.provider) + ctas.comment.map(text => properties += ("comment" -> text)) + ctas.location + .orElse(ctas.options.get("path")) + .map(loc => properties += ("location" -> loc)) + + CreateTableAsSelect( + catalog, + identifier, + partitioning, + ctas.asSelect, + properties.toMap, + writeOptions = options, + ignoreIfExists = ctas.ifNotExists) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index b4b21e1b6d69e..165553c4da5bb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -17,18 +17,20 @@ package org.apache.spark.sql.execution.datasources.v2 +import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.spark.sql.{AnalysisException, Strategy} import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} import org.apache.spark.sql.sources import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} +import org.apache.spark.sql.util.CaseInsensitiveStringMap object DataSourceV2Strategy extends Strategy with PredicateHelper { @@ -146,6 +148,11 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case CreateTableAsSelect(catalog, ident, parts, query, props, options, ifNotExists) => + val writeOptions = new CaseInsensitiveStringMap(options.asJava) + CreateTableAsSelectExec( + catalog, ident, parts, planLater(query), props, writeOptions, ifNotExists) :: Nil + case AppendData(r: DataSourceV2Relation, query, _) => AppendDataExec(r.table.asWritable, r.options, planLater(query)) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 607f2fa0f82c8..1797166bbe0b0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.v2 import java.util.UUID +import scala.collection.JavaConverters._ import scala.util.control.NonFatal import org.apache.spark.{SparkEnv, SparkException, TaskContext} @@ -26,7 +27,10 @@ import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD import org.apache.spark.sql.SaveMode +import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog} +import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} @@ -47,6 
+51,60 @@ case class WriteToDataSourceV2(batchWrite: BatchWrite, query: LogicalPlan) override def output: Seq[Attribute] = Nil } +/** + * Physical plan node for v2 create table as select. + * + * A new table will be created using the schema of the query, and rows from the query are appended. + * If either table creation or the append fails, the table will be deleted. This implementation does + * not provide an atomic CTAS. + */ +case class CreateTableAsSelectExec( + catalog: TableCatalog, + ident: Identifier, + partitioning: Seq[Transform], + query: SparkPlan, + properties: Map[String, String], + writeOptions: CaseInsensitiveStringMap, + ifNotExists: Boolean) extends V2TableWriteExec { + + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits.IdentifierHelper + + override protected def doExecute(): RDD[InternalRow] = { + if (catalog.tableExists(ident)) { + if (ifNotExists) { + return sparkContext.parallelize(Seq.empty, 1) + } + + throw new TableAlreadyExistsException(ident) + } + + Utils.tryWithSafeFinallyAndFailureCallbacks({ + catalog.createTable(ident, query.schema, partitioning.toArray, properties.asJava) match { + case table: SupportsWrite => + val builder = table.newWriteBuilder(writeOptions) + .withInputDataSchema(query.schema) + .withQueryId(UUID.randomUUID().toString) + val batchWrite = builder match { + case supportsSaveMode: SupportsSaveMode => + supportsSaveMode.mode(SaveMode.Append).buildForBatch() + + case _ => + builder.buildForBatch() + } + + doWrite(batchWrite) + + case _ => + // table does not support writes + throw new SparkException(s"Table implementation does not support writes: ${ident.quoted}") + } + + })(catchBlock = { + catalog.dropTable(ident) + }) + } +} + /** * Physical plan node for append into a v2 table. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index d5543e8a31aad..588f5dde85930 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -161,7 +161,7 @@ abstract class BaseSessionStateBuilder( new FindDataSourceTable(session) +: new ResolveSQLOnFile(session) +: new FallbackOrcDataSourceV2(session) +: - DataSourceResolution(conf) +: + DataSourceResolution(conf, session.catalog(_)) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 7fae54bb95ed1..c525b4cbcba57 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -20,19 +20,39 @@ package org.apache.spark.sql.execution.command import java.net.URI import org.apache.spark.sql.{AnalysisException, SaveMode} +import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, Identifier, TableCatalog, TestTableCatalog} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import 
org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LogicalPlan} import org.apache.spark.sql.execution.datasources.{CreateTable, DataSourceResolution} +import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.sql.util.CaseInsensitiveStringMap class PlanResolutionSuite extends AnalysisTest { import CatalystSqlParser._ + private val orc2 = classOf[OrcDataSourceV2].getName + + private val testCat: TableCatalog = { + val newCatalog = new TestTableCatalog + newCatalog.initialize("testcat", CaseInsensitiveStringMap.empty()) + newCatalog + } + + private val lookupCatalog: String => CatalogPlugin = { + case "testcat" => + testCat + case name => + throw new CatalogNotFoundException(s"No such catalog: $name") + } + def parseAndResolve(query: String): LogicalPlan = { - DataSourceResolution(conf).apply(parsePlan(query)) + val newConf = conf.copy() + newConf.setConfString("spark.sql.default.catalog", "testcat") + DataSourceResolution(newConf, lookupCatalog).apply(parsePlan(query)) } private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { @@ -274,4 +294,72 @@ class PlanResolutionSuite extends AnalysisTest { assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) } } + + test("Test v2 CTAS with known catalog in identifier") { + val sql = + s""" + |CREATE TABLE IF NOT EXISTS testcat.mydb.table_name + |USING parquet + |COMMENT 'table comment' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |OPTIONS (path 's3://bucket/path/to/data', other 20) + |AS SELECT * FROM src + """.stripMargin + + val expectedProperties = Map( + "p1" -> "v1", + "p2" -> "v2", + "other" -> "20", + "provider" -> "parquet", + "location" -> "s3://bucket/path/to/data", + "comment" -> "table comment") + + parseAndResolve(sql) match { + case ctas: CreateTableAsSelect => + assert(ctas.catalog.name == "testcat") + assert(ctas.tableName == Identifier.of(Array("mydb"), "table_name")) + assert(ctas.properties == expectedProperties) + assert(ctas.writeOptions == Map("other" -> "20")) + assert(ctas.partitioning.isEmpty) + assert(ctas.ignoreIfExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableAsSelect].getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + // TODO(rblue): enable this test after the default catalog is available + ignore("Test v2 CTAS with data source v2 provider") { + val sql = + s""" + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING $orc2 + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val expectedProperties = Map( + "p1" -> "v1", + "p2" -> "v2", + "provider" -> orc2, + "location" -> "/user/external/page_view", + "comment" -> "This is the staging page view table") + + parseAndResolve(sql) match { + case ctas: CreateTableAsSelect => + assert(ctas.catalog.name == "testcat") + assert(ctas.tableName == Identifier.of(Array("mydb"), "page_view")) + assert(ctas.properties == expectedProperties) + assert(ctas.writeOptions.isEmpty) + assert(ctas.partitioning.isEmpty) + assert(ctas.ignoreIfExists) + + case other => + fail(s"Expected to parse ${classOf[CreateTableAsSelect].getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala new file mode 100644 index 0000000000000..a9bc0369ad20f --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import scala.collection.JavaConverters._ + +import org.scalatest.BeforeAndAfter + +import org.apache.spark.sql.{AnalysisException, QueryTest} +import org.apache.spark.sql.catalog.v2.Identifier +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{LongType, StringType, StructType} + +class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAndAfter { + + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + + private val orc2 = classOf[OrcDataSourceV2].getName + + before { + spark.conf.set("spark.sql.catalog.testcat", classOf[TestInMemoryTableCatalog].getName) + spark.conf.set("spark.sql.default.catalog", "testcat") + + val df = spark.createDataFrame(Seq((1L, "a"), (2L, "b"), (3L, "c"))).toDF("id", "data") + df.createOrReplaceTempView("source") + val df2 = spark.createDataFrame(Seq((4L, "d"), (5L, "e"), (6L, "f"))).toDF("id", "data") + df2.createOrReplaceTempView("source2") + } + + after { + spark.catalog("testcat").asInstanceOf[TestInMemoryTableCatalog].clearTables() + spark.sql("DROP TABLE source") + } + + test("CreateTableAsSelect: use v2 plan because catalog is set") { + spark.sql("CREATE TABLE testcat.table_name USING foo AS SELECT id, data FROM source") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> "foo").asJava) + assert(table.schema == new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType)) + + val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source")) + } + + // TODO(rblue): enable this test after the default catalog is available + ignore("CreateTableAsSelect: use v2 plan because provider is v2") { + spark.sql(s"CREATE TABLE table_name USING $orc2 AS SELECT id, data FROM source") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == 
Map("provider" -> orc2).asJava) + assert(table.schema == new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType)) + + val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source")) + } + + test("CreateTableAsSelect: fail if table exists") { + spark.sql("CREATE TABLE testcat.table_name USING foo AS SELECT id, data FROM source") + + val testCatalog = spark.catalog("testcat").asTableCatalog + + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> "foo").asJava) + assert(table.schema == new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType)) + + val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source")) + + // run a second CTAS query that should fail + val exc = intercept[TableAlreadyExistsException] { + spark.sql( + "CREATE TABLE testcat.table_name USING bar AS SELECT id, data, id as id2 FROM source2") + } + + assert(exc.getMessage.contains("table_name")) + + // table should not have changed + val table2 = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + assert(table2.name == "testcat.table_name") + assert(table2.partitioning.isEmpty) + assert(table2.properties == Map("provider" -> "foo").asJava) + assert(table2.schema == new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType)) + + val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), spark.table("source")) + } + + test("CreateTableAsSelect: if not exists") { + spark.sql( + "CREATE TABLE IF NOT EXISTS testcat.table_name USING foo AS SELECT id, data FROM source") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> "foo").asJava) + assert(table.schema == new StructType() + .add("id", LongType, nullable = false) + .add("data", StringType)) + + val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source")) + + spark.sql( + "CREATE TABLE IF NOT EXISTS testcat.table_name USING foo AS SELECT id, data FROM source2") + + // check that the table contains data from just the first CTAS + val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), spark.table("source")) + } + + test("CreateTableAsSelect: fail analysis when default catalog is needed but missing") { + val originalDefaultCatalog = conf.getConfString("spark.sql.default.catalog") + try { + conf.unsetConf("spark.sql.default.catalog") + + val exc = intercept[AnalysisException] { + spark.sql(s"CREATE TABLE table_name USING $orc2 AS SELECT id, data FROM source") + } + + assert(exc.getMessage.contains("No catalog specified for table")) + assert(exc.getMessage.contains("table_name")) + assert(exc.getMessage.contains("no default catalog is set")) + + } finally { + conf.setConfString("spark.sql.default.catalog", originalDefaultCatalog) + } + } +} diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala new file mode 100644 index 0000000000000..2ecf1c2f184fb --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala @@ -0,0 +1,229 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.sources.v2 + +import java.util +import java.util.concurrent.ConcurrentHashMap + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import org.apache.spark.sql.catalog.v2.{CatalogV2Implicits, Identifier, TableCatalog, TableChange, TestTableCatalog} +import org.apache.spark.sql.catalog.v2.expressions.Transform +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException} +import org.apache.spark.sql.sources.v2.reader.{Batch, InputPartition, PartitionReader, PartitionReaderFactory, Scan, ScanBuilder} +import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriter, DataWriterFactory, SupportsTruncate, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +// this is currently in the spark-sql module because the read and write API is not in catalyst +// TODO(rdblue): when the v2 source API is in catalyst, merge with TestTableCatalog/InMemoryTable +class TestInMemoryTableCatalog extends TableCatalog { + import CatalogV2Implicits._ + + private val tables: util.Map[Identifier, InMemoryTable] = + new ConcurrentHashMap[Identifier, InMemoryTable]() + private var _name: Option[String] = None + + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = { + _name = Some(name) + } + + override def name: String = _name.get + + override def listTables(namespace: Array[String]): Array[Identifier] = { + tables.keySet.asScala.filter(_.namespace.sameElements(namespace)).toArray + } + + override def loadTable(ident: Identifier): Table = { + Option(tables.get(ident)) match { + case Some(table) => + table + case _ => + throw new NoSuchTableException(ident) + } + } + + override def createTable( + ident: Identifier, + schema: StructType, + partitions: Array[Transform], + properties: util.Map[String, String]): Table = { + + if (tables.containsKey(ident)) { + throw new TableAlreadyExistsException(ident) + } + + if (partitions.nonEmpty) { + throw new UnsupportedOperationException( + s"Catalog $name: Partitioned tables are not supported") + } + + val table = new InMemoryTable(s"$name.${ident.quoted}", schema, properties) + + tables.put(ident, table) + + table + } + + override def alterTable(ident: Identifier, 
changes: TableChange*): Table = { + Option(tables.get(ident)) match { + case Some(table) => + val properties = TestTableCatalog.applyPropertiesChanges(table.properties, changes) + val schema = TestTableCatalog.applySchemaChanges(table.schema, changes) + val newTable = new InMemoryTable(table.name, schema, properties, table.data) + + tables.put(ident, newTable) + + newTable + case _ => + throw new NoSuchTableException(ident) + } + } + + override def dropTable(ident: Identifier): Boolean = Option(tables.remove(ident)).isDefined + + def clearTables(): Unit = { + tables.clear() + } +} + +/** + * A simple in-memory table. Rows are stored as a buffered group produced by each output task. + */ +private class InMemoryTable( + val name: String, + val schema: StructType, + override val properties: util.Map[String, String]) + extends Table with SupportsRead with SupportsWrite { + + def this( + name: String, + schema: StructType, + properties: util.Map[String, String], + data: Array[BufferedRows]) = { + this(name, schema, properties) + replaceData(data) + } + + def rows: Seq[InternalRow] = data.flatMap(_.rows) + + @volatile var data: Array[BufferedRows] = Array.empty + + def replaceData(buffers: Array[BufferedRows]): Unit = synchronized { + data = buffers + } + + override def capabilities: util.Set[TableCapability] = Set( + TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.TRUNCATE).asJava + + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + () => new InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition])) + } + + class InMemoryBatchScan(data: Array[InputPartition]) extends Scan with Batch { + override def readSchema(): StructType = schema + + override def toBatch: Batch = this + + override def planInputPartitions(): Array[InputPartition] = data + + override def createReaderFactory(): PartitionReaderFactory = BufferedRowsReaderFactory + } + + override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { + new WriteBuilder with SupportsTruncate { + private var shouldTruncate: Boolean = false + + override def truncate(): WriteBuilder = { + shouldTruncate = true + this + } + + override def buildForBatch(): BatchWrite = { + if (shouldTruncate) TruncateAndAppend else Append + } + } + } + + private object TruncateAndAppend extends BatchWrite { + override def createBatchWriterFactory(): DataWriterFactory = { + BufferedRowsWriterFactory + } + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + replaceData(messages.map(_.asInstanceOf[BufferedRows])) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + } + } + + private object Append extends BatchWrite { + override def createBatchWriterFactory(): DataWriterFactory = { + BufferedRowsWriterFactory + } + + override def commit(messages: Array[WriterCommitMessage]): Unit = { + replaceData(data ++ messages.map(_.asInstanceOf[BufferedRows])) + } + + override def abort(messages: Array[WriterCommitMessage]): Unit = { + } + } +} + +private class BufferedRows extends WriterCommitMessage with InputPartition with Serializable { + val rows = new mutable.ArrayBuffer[InternalRow]() +} + +private object BufferedRowsReaderFactory extends PartitionReaderFactory { + override def createReader(partition: InputPartition): PartitionReader[InternalRow] = { + new BufferedRowsReader(partition.asInstanceOf[BufferedRows]) + } +} + +private class BufferedRowsReader(partition: BufferedRows) extends PartitionReader[InternalRow] { + private var index: Int = -1 + + 
override def next(): Boolean = { + index += 1 + index < partition.rows.length + } + + override def get(): InternalRow = partition.rows(index) + + override def close(): Unit = {} +} + +private object BufferedRowsWriterFactory extends DataWriterFactory { + override def createWriter(partitionId: Int, taskId: Long): DataWriter[InternalRow] = { + new BufferWriter + } +} + +private class BufferWriter extends DataWriter[InternalRow] { + private val buffer = new BufferedRows + + override def write(row: InternalRow): Unit = buffer.rows.append(row.copy()) + + override def commit(): WriterCommitMessage = buffer + + override def abort(): Unit = {} +} diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 877a0dadf0b03..23c777ea1030b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -73,7 +73,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session new FindDataSourceTable(session) +: new ResolveSQLOnFile(session) +: new FallbackOrcDataSourceV2(session) +: - DataSourceResolution(conf) +: + DataSourceResolution(conf, session.catalog(_)) +: customResolutionRules override val postHocResolutionRules: Seq[Rule[LogicalPlan]] = From e133b92c86e93aa58daebec76dd0d12a57c869f2 Mon Sep 17 00:00:00 2001 From: mcheah Date: Wed, 15 May 2019 15:12:02 -0700 Subject: [PATCH 30/70] Fix scala 2.11 compilation --- .../apache/spark/sql/catalog/v2/expressions/Expressions.java | 4 ++-- .../org/apache/spark/sql/catalyst/util/DateTimeUtils.scala | 5 ++++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java index 7b264e7480e17..c4860f56ae76d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java @@ -42,7 +42,7 @@ private Expressions() { */ public static Transform apply(String name, Expression... args) { return LogicalExpressions.apply(name, - JavaConverters.asScalaBuffer(Arrays.asList(args)).toSeq()); + JavaConverters.asScalaBufferConverter(Arrays.asList(args)).asScala()); } /** @@ -83,7 +83,7 @@ public static Literal literal(T value) { */ public static Transform bucket(int numBuckets, String... 
columns) { return LogicalExpressions.bucket(numBuckets, - JavaConverters.asScalaBuffer(Arrays.asList(columns)).toSeq()); + JavaConverters.asScalaBufferConverter(Arrays.asList(columns)).asScala()); } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index f590c63f80b21..84d951cf6e7d4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.util import java.sql.{Date, Timestamp} import java.text.{DateFormat, SimpleDateFormat} +import java.time.ZoneId import java.util.{Calendar, Locale, TimeZone} import java.util.concurrent.ConcurrentHashMap import java.util.function.{Function => JFunction} -import javax.xml.bind.DatatypeConverter +import javax.xml.bind.DatatypeConverter import scala.annotation.tailrec import org.apache.spark.unsafe.types.UTF8String @@ -123,6 +124,8 @@ object DateTimeUtils { override def apply(timeZoneId: String): TimeZone = TimeZone.getTimeZone(timeZoneId) } + def getZoneId(timeZoneId: String): ZoneId = ZoneId.of(timeZoneId, ZoneId.SHORT_IDS) + def getTimeZone(timeZoneId: String): TimeZone = { computedTimeZones.computeIfAbsent(timeZoneId, computeTimeZone) } From 6dbc1d3e01eb2c21c3f2e72bd6ef25069ed7694a Mon Sep 17 00:00:00 2001 From: mcheah Date: Wed, 15 May 2019 15:40:13 -0700 Subject: [PATCH 31/70] Fix style --- .../spark/sql/catalog/v2/expressions/Expressions.java | 7 ++++--- .../org/apache/spark/sql/catalyst/util/DateTimeUtils.scala | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java index c4860f56ae76d..d8e49beb0bca5 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/expressions/Expressions.java @@ -17,11 +17,12 @@ package org.apache.spark.sql.catalog.v2.expressions; -import org.apache.spark.annotation.Experimental; -import org.apache.spark.sql.types.DataType; +import java.util.Arrays; + import scala.collection.JavaConverters; -import java.util.Arrays; +import org.apache.spark.annotation.Experimental; +import org.apache.spark.sql.types.DataType; /** * Helper methods to create logical transforms to pass into Spark. 
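The Expressions.java changes above trace back to [PATCH 30/70]: the direct `JavaConverters.asScalaBuffer(...)` call did not compile against Scala 2.11, so it was replaced with the decorator form `asScalaBufferConverter(...).asScala()`, which is available on both 2.11 and 2.12. A small Scala sketch of the same conversion pattern (illustration only, not part of either patch):

```
import java.util.Arrays

import scala.collection.JavaConverters._

object JavaConvertersSketch {
  def main(args: Array[String]): Unit = {
    val javaList: java.util.List[String] = Arrays.asList("a", "b", "c")

    // Decorator-style conversion, the form Expressions.java uses after the fix.
    val seq1: Seq[String] = asScalaBufferConverter(javaList).asScala

    // Equivalent, and the usual spelling from Scala code.
    val seq2: Seq[String] = javaList.asScala

    assert(seq1 == Seq("a", "b", "c") && seq2 == seq1)
  }
}
```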
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala index 84d951cf6e7d4..a85cad35ac6fc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/DateTimeUtils.scala @@ -23,8 +23,8 @@ import java.time.ZoneId import java.util.{Calendar, Locale, TimeZone} import java.util.concurrent.ConcurrentHashMap import java.util.function.{Function => JFunction} - import javax.xml.bind.DatatypeConverter + import scala.annotation.tailrec import org.apache.spark.unsafe.types.UTF8String From d49a179b450bd2b7ed4e3c619480350afdef5702 Mon Sep 17 00:00:00 2001 From: Gengliang Wang Date: Tue, 19 Mar 2019 13:35:47 +0800 Subject: [PATCH 32/70] [SPARK-27162][SQL] Add new method asCaseSensitiveMap in CaseInsensitiveStringMap Currently, DataFrameReader/DataFrameReader supports setting Hadoop configurations via method `.option()`. E.g, the following test case should be passed in both ORC V1 and V2 ``` class TestFileFilter extends PathFilter { override def accept(path: Path): Boolean = path.getParent.getName != "p=2" } withTempPath { dir => val path = dir.getCanonicalPath val df = spark.range(2) df.write.orc(path + "/p=1") df.write.orc(path + "/p=2") val extraOptions = Map( "mapred.input.pathFilter.class" -> classOf[TestFileFilter].getName, "mapreduce.input.pathFilter.class" -> classOf[TestFileFilter].getName ) assert(spark.read.options(extraOptions).orc(path).count() === 2) } } ``` While Hadoop Configurations are case sensitive, the current data source V2 APIs are using `CaseInsensitiveStringMap` in the top level entry `TableProvider`. To create Hadoop configurations correctly, I suggest 1. adding a new method `asCaseSensitiveMap` in `CaseInsensitiveStringMap`. 2. Make `CaseInsensitiveStringMap` read-only to ambiguous conversion in `asCaseSensitiveMap` Unit test Closes #24094 from gengliangwang/originalMap. Authored-by: Gengliang Wang Signed-off-by: Wenchen Fan --- .../apache/spark/sql/catalog/v2/Catalogs.java | 5 ++- .../sql/util/CaseInsensitiveStringMap.java | 37 +++++++++++++++---- .../util/CaseInsensitiveStringMapSuite.scala | 31 ++++++++++++++-- .../execution/datasources/v2/FileTable.scala | 7 ++-- .../datasources/v2/FileWriteBuilder.scala | 7 ++-- .../datasources/v2/orc/OrcScanBuilder.scala | 6 ++- .../orc/OrcPartitionDiscoverySuite.scala | 23 ++++++++++++ 7 files changed, 96 insertions(+), 20 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java index bcb1f56789daf..851a6a9f6d165 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/Catalogs.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap; import org.apache.spark.util.Utils; +import java.util.HashMap; import java.util.Map; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -98,7 +99,7 @@ private static CaseInsensitiveStringMap catalogOptions(String name, SQLConf conf Map allConfs = mapAsJavaMapConverter(conf.getAllConfs()).asJava(); Pattern prefix = Pattern.compile("^spark\\.sql\\.catalog\\." 
+ name + "\\.(.+)"); - CaseInsensitiveStringMap options = CaseInsensitiveStringMap.empty(); + HashMap options = new HashMap<>(); for (Map.Entry entry : allConfs.entrySet()) { Matcher matcher = prefix.matcher(entry.getKey()); if (matcher.matches() && matcher.groupCount() > 0) { @@ -106,6 +107,6 @@ private static CaseInsensitiveStringMap catalogOptions(String name, SQLConf conf } } - return options; + return new CaseInsensitiveStringMap(options); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java index 704d90ed60adc..da41346d7ce71 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/util/CaseInsensitiveStringMap.java @@ -18,8 +18,11 @@ package org.apache.spark.sql.util; import org.apache.spark.annotation.Experimental; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import java.util.Collection; +import java.util.Collections; import java.util.HashMap; import java.util.Locale; import java.util.Map; @@ -35,16 +38,29 @@ */ @Experimental public class CaseInsensitiveStringMap implements Map { + private final Logger logger = LoggerFactory.getLogger(CaseInsensitiveStringMap.class); + + private String unsupportedOperationMsg = "CaseInsensitiveStringMap is read-only."; public static CaseInsensitiveStringMap empty() { return new CaseInsensitiveStringMap(new HashMap<>(0)); } + private final Map original; + private final Map delegate; public CaseInsensitiveStringMap(Map originalMap) { - this.delegate = new HashMap<>(originalMap.size()); - putAll(originalMap); + original = new HashMap<>(originalMap); + delegate = new HashMap<>(originalMap.size()); + for (Map.Entry entry : originalMap.entrySet()) { + String key = toLowerCase(entry.getKey()); + if (delegate.containsKey(key)) { + logger.warn("Converting duplicated key " + entry.getKey() + + " into CaseInsensitiveStringMap."); + } + delegate.put(key, entry.getValue()); + } } @Override @@ -78,24 +94,22 @@ public String get(Object key) { @Override public String put(String key, String value) { - return delegate.put(toLowerCase(key), value); + throw new UnsupportedOperationException(unsupportedOperationMsg); } @Override public String remove(Object key) { - return delegate.remove(toLowerCase(key)); + throw new UnsupportedOperationException(unsupportedOperationMsg); } @Override public void putAll(Map m) { - for (Map.Entry entry : m.entrySet()) { - put(entry.getKey(), entry.getValue()); - } + throw new UnsupportedOperationException(unsupportedOperationMsg); } @Override public void clear() { - delegate.clear(); + throw new UnsupportedOperationException(unsupportedOperationMsg); } @Override @@ -157,4 +171,11 @@ public double getDouble(String key, double defaultValue) { String value = get(key); return value == null ? defaultValue : Double.parseDouble(value); } + + /** + * Returns the original case-sensitive map. 
+ */ + public Map asCaseSensitiveMap() { + return Collections.unmodifiableMap(original); + } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala index 623ddeb140254..0accb471cada3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/CaseInsensitiveStringMapSuite.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.util +import java.util + import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite @@ -25,9 +27,16 @@ class CaseInsensitiveStringMapSuite extends SparkFunSuite { test("put and get") { val options = CaseInsensitiveStringMap.empty() - options.put("kEy", "valUE") - assert(options.get("key") == "valUE") - assert(options.get("KEY") == "valUE") + intercept[UnsupportedOperationException] { + options.put("kEy", "valUE") + } + } + + test("clear") { + val options = new CaseInsensitiveStringMap(Map("kEy" -> "valUE").asJava) + intercept[UnsupportedOperationException] { + options.clear() + } } test("key and value set") { @@ -80,4 +89,20 @@ class CaseInsensitiveStringMapSuite extends SparkFunSuite { options.getDouble("foo", 0.1d) } } + + test("asCaseSensitiveMap") { + val originalMap = new util.HashMap[String, String] { + put("Foo", "Bar") + put("OFO", "ABR") + put("OoF", "bar") + } + + val options = new CaseInsensitiveStringMap(originalMap) + val caseSensitiveMap = options.asCaseSensitiveMap + assert(caseSensitiveMap.equals(originalMap)) + // The result of `asCaseSensitiveMap` is read-only. + intercept[UnsupportedOperationException] { + caseSensitiveMap.put("kEy", "valUE") + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala index deada9a83964b..9cf292782ffe0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileTable.scala @@ -40,15 +40,16 @@ abstract class FileTable( import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ lazy val fileIndex: PartitioningAwareFileIndex = { - val scalaMap = options.asScala.toMap - val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(scalaMap) + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) // This is an internal config so must be present. 
val checkFilesExist = options.get("check_files_exist").toBoolean val rootPathsSpecified = DataSource.checkAndGlobPathIfNecessary(paths, hadoopConf, checkEmptyGlobPath = true, checkFilesExist = checkFilesExist) val fileStatusCache = FileStatusCache.getOrCreate(sparkSession) new InMemoryFileIndex( - sparkSession, rootPathsSpecified, scalaMap, userSpecifiedSchema, fileStatusCache) + sparkSession, rootPathsSpecified, caseSensitiveMap, userSpecifiedSchema, fileStatusCache) } lazy val dataSchema: StructType = userSpecifiedSchema.orElse { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala index e16ee4c460f39..c19a58034484a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala @@ -63,15 +63,16 @@ abstract class FileWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[St validateInputs() val path = new Path(paths.head) val sparkSession = SparkSession.active - val optionsAsScala = options.asScala.toMap - val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(optionsAsScala) + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) val job = getJobInstance(hadoopConf, path) val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, jobId = java.util.UUID.randomUUID().toString, outputPath = paths.head) lazy val description = - createWriteJobDescription(sparkSession, hadoopConf, job, paths.head, optionsAsScala) + createWriteJobDescription(sparkSession, hadoopConf, job, paths.head, options.asScala.toMap) val fs = path.getFileSystem(hadoopConf) mode match { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala index 0b153416b7bb0..a2c55e8c43021 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcScanBuilder.scala @@ -36,7 +36,11 @@ case class OrcScanBuilder( schema: StructType, dataSchema: StructType, options: CaseInsensitiveStringMap) extends FileScanBuilder(schema) { - lazy val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(options.asScala.toMap) + lazy val hadoopConf = { + val caseSensitiveMap = options.asCaseSensitiveMap.asScala.toMap + // Hadoop Configurations are case sensitive. 
+ sparkSession.sessionState.newHadoopConfWithOptions(caseSensitiveMap) + } override def build(): Scan = { OrcScan(sparkSession, hadoopConf, fileIndex, dataSchema, readSchema) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala index 4a695ac74c476..b4d92c3b2d2fa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionDiscoverySuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution.datasources.orc import java.io.File +import org.apache.hadoop.fs.{Path, PathFilter} + import org.apache.spark.SparkConf import org.apache.spark.sql._ import org.apache.spark.sql.internal.SQLConf @@ -30,6 +32,10 @@ case class OrcParData(intField: Int, stringField: String) // The data that also includes the partitioning key case class OrcParDataWithKey(intField: Int, pi: Int, stringField: String, ps: String) +class TestFileFilter extends PathFilter { + override def accept(path: Path): Boolean = path.getParent.getName != "p=2" +} + abstract class OrcPartitionDiscoveryTest extends OrcTest { val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__" @@ -226,6 +232,23 @@ abstract class OrcPartitionDiscoveryTest extends OrcTest { } } } + + test("SPARK-27162: handle pathfilter configuration correctly") { + withTempPath { dir => + val path = dir.getCanonicalPath + + val df = spark.range(2) + df.write.orc(path + "/p=1") + df.write.orc(path + "/p=2") + assert(spark.read.orc(path).count() === 4) + + val extraOptions = Map( + "mapred.input.pathFilter.class" -> classOf[TestFileFilter].getName, + "mapreduce.input.pathFilter.class" -> classOf[TestFileFilter].getName + ) + assert(spark.read.options(extraOptions).orc(path).count() === 2) + } + } } class OrcPartitionDiscoverySuite extends OrcPartitionDiscoveryTest with SharedSQLContext From affb14b268f9b33bac8a3e02749ce67cd80855b8 Mon Sep 17 00:00:00 2001 From: mcheah Date: Wed, 15 May 2019 15:46:23 -0700 Subject: [PATCH 33/70] Fix compilation --- .../main/scala/org/apache/spark/sql/execution/HiveResult.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index a369b49777a5f..41cebc247a186 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -21,7 +21,7 @@ import java.nio.charset.StandardCharsets import java.sql.{Date, Timestamp} import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, TimestampFormatter} +import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.command.{DescribeCommandBase, ExecutedCommandExec, ShowTablesCommand} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ From 46616714322fdcfcb7ef72a36186a7ed2245aa16 Mon Sep 17 00:00:00 2001 From: mcheah Date: Wed, 15 May 2019 15:50:37 -0700 Subject: [PATCH 34/70] More Scala 2.11 stuff --- .../spark/sql/sources/v2/TestInMemoryTableCatalog.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala index 2ecf1c2f184fb..42c2db2539060 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/TestInMemoryTableCatalog.scala @@ -134,7 +134,9 @@ private class InMemoryTable( TableCapability.BATCH_READ, TableCapability.BATCH_WRITE, TableCapability.TRUNCATE).asJava override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { - () => new InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition])) + new ScanBuilder() { + def build(): Scan = new InMemoryBatchScan(data.map(_.asInstanceOf[InputPartition])) + } } class InMemoryBatchScan(data: Array[InputPartition]) extends Scan with Batch { From ee834f7ab6539b33d058922da0d51e239cd54aa0 Mon Sep 17 00:00:00 2001 From: Hyukjin Kwon Date: Mon, 18 Feb 2019 21:13:00 +0800 Subject: [PATCH 35/70] [SPARK-26744][SPARK-26744][SQL][HOTFOX] Disable schema validation tests for FileDataSourceV2 (partially revert ) ## What changes were proposed in this pull request? This PR partially revert SPARK-26744. https://github.com/apache/spark/commit/60caa92deaf6941f58da82dcc0962ebf3a598ced and https://github.com/apache/spark/commit/4dce45a5992e6a89a26b5a0739b33cfeaf979208 were merged at similar time range independently. So the test failures were not caught. - https://github.com/apache/spark/commit/60caa92deaf6941f58da82dcc0962ebf3a598ced happened to add a schema reading logic in writing path for overwrite mode as well. - https://github.com/apache/spark/commit/4dce45a5992e6a89a26b5a0739b33cfeaf979208 added some tests with overwrite modes with migrated ORC v2. And the tests looks starting to fail. I guess the discussion won't be short (see https://github.com/apache/spark/pull/23606#discussion_r257675083) and this PR proposes to disable the tests added at https://github.com/apache/spark/commit/4dce45a5992e6a89a26b5a0739b33cfeaf979208 to unblock other PRs for now. ## How was this patch tested? Existing tests. Closes #23828 from HyukjinKwon/SPARK-26744. 
Authored-by: Hyukjin Kwon Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala index 591884095ec38..e46802f69ed67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/FileBasedDataSourceSuite.scala @@ -334,7 +334,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo test("SPARK-24204 error handling for unsupported Interval data types - csv, json, parquet, orc") { withTempDir { dir => val tempDir = new File(dir, "files").getCanonicalPath - Seq(true, false).foreach { useV1 => + Seq(true).foreach { useV1 => val useV1List = if (useV1) { "orc" } else { @@ -379,7 +379,7 @@ class FileBasedDataSourceSuite extends QueryTest with SharedSQLContext with Befo } test("SPARK-24204 error handling for unsupported Null data types - csv, parquet, orc") { - Seq(true, false).foreach { useV1 => + Seq(true).foreach { useV1 => val useV1List = if (useV1) { "orc" } else { From 5e7eb1251e96f65326dec18170e7f72f990140ca Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Apr 2019 10:31:27 +0800 Subject: [PATCH 36/70] [SPARK-26811][SQL][FOLLOWUP] fix some documentation ## What changes were proposed in this pull request? It's a followup of https://github.com/apache/spark/pull/24012 , to fix 2 documentation: 1. `SupportsRead` and `SupportsWrite` are not internal anymore. They are public interfaces now. 2. `Scan` should link the `BATCH_READ` instead of hardcoding it. ## How was this patch tested? N/A Closes #24285 from cloud-fan/doc. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/sources/v2/SupportsRead.java | 2 +- .../org/apache/spark/sql/sources/v2/SupportsWrite.java | 2 +- .../java/org/apache/spark/sql/sources/v2/reader/Scan.java | 8 +++++--- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java index 67fc72e070dc9..826fa2f8a0720 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java @@ -22,7 +22,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** - * An internal base interface of mix-in interfaces for readable {@link Table}. This adds + * A mix-in interface of {@link Table}, to indicate that it's readable. This adds * {@link #newScanBuilder(CaseInsensitiveStringMap)} that is used to create a scan for batch, * micro-batch, or continuous processing. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java index b215963868217..c52e54569dc0c 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java @@ -22,7 +22,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap; /** - * An internal base interface of mix-in interfaces for writable {@link Table}. This adds + * A mix-in interface of {@link Table}, to indicate that it's writable. 
This adds * {@link #newWriteBuilder(CaseInsensitiveStringMap)} that is used to create a write * for batch or streaming. */ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java index e97d0548c66ff..7633d504d36b1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java @@ -24,6 +24,7 @@ import org.apache.spark.sql.sources.v2.SupportsContinuousRead; import org.apache.spark.sql.sources.v2.SupportsMicroBatchRead; import org.apache.spark.sql.sources.v2.Table; +import org.apache.spark.sql.sources.v2.TableCapability; /** * A logical representation of a data source scan. This interface is used to provide logical @@ -32,8 +33,8 @@ * This logical representation is shared between batch scan, micro-batch streaming scan and * continuous streaming scan. Data sources must implement the corresponding methods in this * interface, to match what the table promises to support. For example, {@link #toBatch()} must be - * implemented, if the {@link Table} that creates this {@link Scan} returns BATCH_READ support in - * its {@link Table#capabilities()}. + * implemented, if the {@link Table} that creates this {@link Scan} returns + * {@link TableCapability#BATCH_READ} support in its {@link Table#capabilities()}. *
*/ @Evolving @@ -61,7 +62,8 @@ default String description() { /** * Returns the physical representation of this scan for batch query. By default this method throws * exception, data sources must overwrite this method to provide an implementation, if the - * {@link Table} that creates this returns batch read support in its {@link Table#capabilities()}. + * {@link Table} that creates this scan returns {@link TableCapability#BATCH_READ} in its + * {@link Table#capabilities()}. * * @throws UnsupportedOperationException */ From 57153b476eb78a4587d83dd340638e5d5230dba1 Mon Sep 17 00:00:00 2001 From: uncleGen Date: Sat, 27 Apr 2019 09:28:31 +0800 Subject: [PATCH 37/70] [MINOR][TEST][DOC] Execute action miss name message ## What changes were proposed in this pull request? some minor updates: - `Execute` action miss `name` message - typo in SS document - typo in SQLConf ## How was this patch tested? N/A Closes #24466 from uncleGen/minor-fix. Authored-by: uncleGen Signed-off-by: Wenchen Fan --- docs/structured-streaming-programming-guide.md | 2 +- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 4 ++-- .../src/main/scala/org/apache/spark/sql/DataFrameReader.scala | 2 +- .../src/main/scala/org/apache/spark/sql/DataFrameWriter.scala | 2 +- .../sql/execution/datasources/DataSourceResolution.scala | 2 +- .../scala/org/apache/spark/sql/streaming/StreamTest.scala | 2 +- 6 files changed, 7 insertions(+), 7 deletions(-) diff --git a/docs/structured-streaming-programming-guide.md b/docs/structured-streaming-programming-guide.md index e76b53dbb4dc3..77db1c3d7d613 100644 --- a/docs/structured-streaming-programming-guide.md +++ b/docs/structured-streaming-programming-guide.md @@ -2980,7 +2980,7 @@ the effect of the change is not well-defined. For all of them: - Changes to the user-defined foreach sink (that is, the `ForeachWriter` code) are allowed, but the semantics of the change depends on the code. -- *Changes in projection / filter / map-like operations**: Some cases are allowed. For example: +- *Changes in projection / filter / map-like operations*: Some cases are allowed. For example: - Addition / deletion of filters is allowed: `sdf.selectExpr("a")` to `sdf.where(...).selectExpr("a").filter(...)`. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 58ad84d029024..759c22189d874 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -2096,9 +2096,9 @@ class SQLConf extends Serializable with Logging { def continuousStreamingExecutorPollIntervalMs: Long = getConf(CONTINUOUS_STREAMING_EXECUTOR_POLL_INTERVAL_MS) - def userV1SourceReaderList: String = getConf(USE_V1_SOURCE_READER_LIST) + def useV1SourceReaderList: String = getConf(USE_V1_SOURCE_READER_LIST) - def userV1SourceWriterList: String = getConf(USE_V1_SOURCE_WRITER_LIST) + def useV1SourceWriterList: String = getConf(USE_V1_SOURCE_WRITER_LIST) def disabledV2StreamingWriters: String = getConf(DISABLED_V2_STREAMING_WRITERS) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala index b6d347e6415a3..0cf9957539e73 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameReader.scala @@ -195,7 +195,7 @@ class DataFrameReader private[sql](sparkSession: SparkSession) extends Logging { } val useV1Sources = - sparkSession.sessionState.conf.userV1SourceReaderList.toLowerCase(Locale.ROOT).split(",") + sparkSession.sessionState.conf.useV1SourceReaderList.toLowerCase(Locale.ROOT).split(",") val lookupCls = DataSource.lookupDataSource(source, sparkSession.sessionState.conf) val cls = lookupCls.newInstance() match { case f: FileDataSourceV2 if useV1Sources.contains(f.shortName()) || diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 289efde2a7f00..228830b56f3aa 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -247,7 +247,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val session = df.sparkSession val useV1Sources = - session.sessionState.conf.userV1SourceWriterList.toLowerCase(Locale.ROOT).split(",") + session.sessionState.conf.useV1SourceWriterList.toLowerCase(Locale.ROOT).split(",") val lookupCls = DataSource.lookupDataSource(source, session.sessionState.conf) val cls = lookupCls.newInstance() match { case f: FileDataSourceV2 if useV1Sources.contains(f.shortName()) || diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index 6d1cbe18c900c..09506f05ccfa4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -76,7 +76,7 @@ case class DataSourceResolution( object V1WriteProvider { private val v1WriteOverrideSet = - conf.userV1SourceWriterList.toLowerCase(Locale.ROOT).split(",").toSet + conf.useV1SourceWriterList.toLowerCase(Locale.ROOT).split(",").toSet def unapply(provider: String): Option[String] = { if (v1WriteOverrideSet.contains(provider.toLowerCase(Locale.ROOT))) { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala 
index da496837e7a19..a8efe5b4e889e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -294,7 +294,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be /** Execute arbitrary code */ object Execute { def apply(name: String)(func: StreamExecution => Any): AssertOnQuery = - AssertOnQuery(query => { func(query); true }, "name") + AssertOnQuery(query => { func(query); true }, name) def apply(func: StreamExecution => Any): AssertOnQuery = apply("Execute")(func) } From 0c2d6aaf7ca80966d26c8bb8ceb42d610e23e196 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 16 May 2019 16:24:53 -0700 Subject: [PATCH 38/70] [SPARK-27576][SQL] table capability to skip the output column resolution Currently we have an analyzer rule, which resolves the output columns of data source v2 writing plans, to make sure the schema of input query is compatible with the table. However, not all data sources need this check. For example, the `NoopDataSource` doesn't care about the schema of input query at all. This PR introduces a new table capability: ACCEPT_ANY_SCHEMA. If a table reports this capability, we skip resolving output columns for it during write. Note that, we already skip resolving output columns for `NoopDataSource` because it implements `SupportsSaveMode`. However, `SupportsSaveMode` is a hack and will be removed soon. new test cases Closes #24469 from cloud-fan/schema-check. Authored-by: Wenchen Fan Signed-off-by: Dongjoon Hyun --- .../spark/sql/sources/v2/TableCapability.java | 7 ++++- .../sql/catalyst/analysis/NamedRelation.scala | 3 ++ .../plans/logical/basicLogicalOperators.scala | 9 ++++-- .../analysis/DataSourceV2AnalysisSuite.scala | 31 +++++++++++++++++-- .../datasources/noop/NoopDataSource.scala | 3 +- .../datasources/v2/DataSourceV2Relation.scala | 2 ++ 6 files changed, 48 insertions(+), 7 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java index 5a9b85e6d0361..33c3d647bf409 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java @@ -65,5 +65,10 @@ public enum TableCapability { *
* See {@code org.apache.spark.sql.sources.v2.writer.SupportsDynamicOverwrite}. */ - OVERWRITE_DYNAMIC + OVERWRITE_DYNAMIC, + + /** + * Signals that the table accepts input of any schema in a write operation. + */ + ACCEPT_ANY_SCHEMA } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala index ad201f947b671..56b8d84441c95 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/NamedRelation.scala @@ -21,4 +21,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan trait NamedRelation extends LogicalPlan { def name: String + + // When false, the schema of input data must match the schema of this relation, during write. + def skipSchemaResolution: Boolean = false } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 2bbe0bb006897..9636f2fec38c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -378,14 +378,17 @@ trait V2WriteCommand extends Command { override lazy val resolved: Boolean = outputResolved def outputResolved: Boolean = { - table.resolved && query.resolved && query.output.size == table.output.size && + // If the table doesn't require schema match, we don't need to resolve the output columns. + table.skipSchemaResolution || { + table.resolved && query.resolved && query.output.size == table.output.size && query.output.zip(table.output).forall { case (inAttr, outAttr) => // names and types must match, nullability must be compatible inAttr.name == outAttr.name && - DataType.equalsIgnoreCompatibleNullability(outAttr.dataType, inAttr.dataType) && - (outAttr.nullable || !inAttr.nullable) + DataType.equalsIgnoreCompatibleNullability(outAttr.dataType, inAttr.dataType) && + (outAttr.nullable || !inAttr.nullable) } + } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala index 0c48548614266..48b43fcccacef 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/DataSourceV2AnalysisSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.catalyst.analysis import java.util.Locale import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, Cast, Expression, LessThanOrEqual, Literal} -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LeafNode, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Project} -import org.apache.spark.sql.types.{DoubleType, FloatType, StructField, StructType} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types._ class V2AppendDataAnalysisSuite extends DataSourceV2AnalysisSuite { override def byName(table: NamedRelation, query: LogicalPlan): LogicalPlan = { @@ -104,6 +104,12 @@ case class TestRelation(output: Seq[AttributeReference]) extends LeafNode with N override def name: String = "table-name" } +case 
class TestRelationAcceptAnySchema(output: Seq[AttributeReference]) + extends LeafNode with NamedRelation { + override def name: String = "test-name" + override def skipSchemaResolution: Boolean = true +} + abstract class DataSourceV2AnalysisSuite extends AnalysisTest { val table = TestRelation(StructType(Seq( StructField("x", FloatType), @@ -446,6 +452,27 @@ abstract class DataSourceV2AnalysisSuite extends AnalysisTest { "Cannot safely cast", "'x'", "DoubleType to FloatType")) } + test("bypass output column resolution") { + val table = TestRelationAcceptAnySchema(StructType(Seq( + StructField("a", FloatType, nullable = false), + StructField("b", DoubleType))).toAttributes) + + val query = TestRelation(StructType(Seq( + StructField("s", StringType))).toAttributes) + + withClue("byName") { + val parsedPlan = byName(table, query) + assertResolved(parsedPlan) + checkAnalysis(parsedPlan, parsedPlan) + } + + withClue("byPosition") { + val parsedPlan = byPosition(table, query) + assertResolved(parsedPlan) + checkAnalysis(parsedPlan, parsedPlan) + } + } + def assertNotResolved(logicalPlan: LogicalPlan): Unit = { assert(!logicalPlan.resolved, s"Plan should not be resolved: $logicalPlan") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 96a78d3a0da20..5e20480c3f272 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -43,7 +43,8 @@ private[noop] object NoopTable extends Table with SupportsWrite with SupportsStr override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = NoopWriteBuilder override def name(): String = "noop-table" override def schema(): StructType = new StructType() - override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_WRITE).asJava + override def capabilities(): util.Set[TableCapability] = Set( + TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA).asJava } private[noop] object NoopWriteBuilder extends WriteBuilder diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala index 411995718603c..fc919439d9224 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Relation.scala @@ -44,6 +44,8 @@ case class DataSourceV2Relation( override def name: String = table.name() + override def skipSchemaResolution: Boolean = table.supports(TableCapability.ACCEPT_ANY_SCHEMA) + override def simpleString(maxFields: Int): String = { s"RelationV2${truncatedString(output, "[", ", ", "]", maxFields)} $name" } From d3e9b9472a69572c752c32ecce4ee7f36cefea5c Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 24 May 2019 10:45:46 -0700 Subject: [PATCH 39/70] [SPARK-26356][SQL] remove SaveMode from data source v2 In data source v1, save mode specified in `DataFrameWriter` is passed to data source implementation directly, and each data source can define its own behavior about save mode. This is confusing and we want to get rid of save mode in data source v2. 
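As a minimal sketch (not taken from this patch) of the v1 contract described above: `KeyValueSource` is a hypothetical source name, while `CreatableRelationProvider.createRelation` is the existing v1 hook that receives the `SaveMode` chosen on `DataFrameWriter` unchanged.
```scala
// Sketch of the DSv1 write path, assuming the existing CreatableRelationProvider API.
// `KeyValueSource` is a hypothetical source used only for illustration.
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider}

class KeyValueSource extends CreatableRelationProvider {
  override def createRelation(
      sqlContext: SQLContext,
      mode: SaveMode,                  // arrives exactly as the user set it on DataFrameWriter
      parameters: Map[String, String],
      data: DataFrame): BaseRelation = {
    mode match {
      case SaveMode.Overwrite     => // this source alone decides what "overwrite" means
      case SaveMode.Append        => // ... and what "append" means
      case SaveMode.ErrorIfExists => // ... and how (or whether) existence is checked
      case SaveMode.Ignore        => // ... and what exactly gets skipped
    }
    ??? // return a BaseRelation describing the written data
  }
}
```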
For data source v2, we expect data source to implement the `TableCatalog` API, and end-users use SQL(or the new write API described in [this doc](https://docs.google.com/document/d/1gYm5Ji2Mge3QBdOliFV5gSPTKlX4q1DCBXIkiyMv62A/edit?ts=5ace0718#heading=h.e9v1af12g5zo)) to acess data sources. The SQL API has very clear semantic and we don't need save mode at all. However, for simple data sources that do not have table management (like a JIRA data source, a noop sink, etc.), it's not ideal to ask them to implement the `TableCatalog` API, and throw exception here and there. `TableProvider` API is created for simple data sources. It can only get tables, without any other table management methods. This means, it can only deal with existing tables. `TableProvider` fits well with `DataStreamReader` and `DataStreamWriter`, as they can only read/write existing tables. However, `TableProvider` doesn't fit `DataFrameWriter` well, as the save mode requires more than just get table. More specifically, `ErrorIfExists` mode needs to check if table exists, and create table. `Ignore` mode needs to check if table exists. When end-users specify `ErrorIfExists` or `Ignore` mode and write data to `TableProvider` via `DataFrameWriter`, Spark fails the query and asks users to use `Append` or `Overwrite` mode. The file source is in the middle of `TableProvider` and `TableCatalog`: it's simple but it can check table(path) exists and create table(path). That said, file source supports all the save modes. Currently file source implements `TableProvider`, and it's not working because `TableProvider` doesn't support `ErrorIfExists` and `Ignore` modes. Ideally we should create a new API for path-based data sources, but to unblock the work of file source v2 migration, this PR proposes to special-case file source v2 in `DataFrameWriter`, to make it work. This PR also removes `SaveMode` from data source v2, as now only the internal file source v2 needs it. existing tests Closes #24233 from cloud-fan/file. Authored-by: Wenchen Fan Signed-off-by: gatorsmile --- .../spark/sql/sources/v2/TableProvider.java | 4 + .../sources/v2/writer/SupportsSaveMode.java | 26 ------ .../sql/sources/v2/writer/WriteBuilder.java | 8 +- .../apache/spark/sql/DataFrameWriter.scala | 80 ++++++++++--------- .../datasources/noop/NoopDataSource.scala | 7 +- .../datasources/v2/FileWriteBuilder.scala | 11 ++- .../v2/WriteToDataSourceV2Exec.scala | 35 ++------ .../sql/sources/v2/DataSourceV2Suite.scala | 51 ++++++------ .../v2/FileDataSourceV2FallBackSuite.scala | 4 +- .../sources/v2/SimpleWritableDataSource.scala | 23 ++---- .../sql/test/DataFrameReaderWriterSuite.scala | 73 ++++++++++++++++- 11 files changed, 166 insertions(+), 156 deletions(-) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsSaveMode.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java index 04ad8fd90be9f..0e2eb9c3cabb7 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java @@ -26,6 +26,10 @@ * The base interface for v2 data sources which don't have a real catalog. Implementations must * have a public, 0-arg constructor. *
+ * Note that, TableProvider can only apply data operations to existing tables, like read, append, + * delete, and overwrite. It does not support the operations that require metadata changes, like + * create/drop tables. + *
* The major responsibility of this interface is to return a {@link Table} for read/write. *
*/ diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsSaveMode.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsSaveMode.java deleted file mode 100644 index c4295f2371877..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsSaveMode.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2.writer; - -import org.apache.spark.sql.SaveMode; - -// A temporary mixin trait for `WriteBuilder` to support `SaveMode`. Will be removed before -// Spark 3.0 when all the new write operators are finished. See SPARK-26356 for more details. -public interface SupportsSaveMode extends WriteBuilder { - WriteBuilder mode(SaveMode mode); -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java index e08d34fbf453e..bfe41f5e8dfb5 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java +++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java @@ -57,12 +57,8 @@ default WriteBuilder withInputDataSchema(StructType schema) { /** * Returns a {@link BatchWrite} to write data to batch source. By default this method throws * exception, data sources must overwrite this method to provide an implementation, if the - * {@link Table} that creates this write returns BATCH_WRITE support in its - * {@link Table#capabilities()}. - * - * Note that, the returned {@link BatchWrite} can be null if the implementation supports SaveMode, - * to indicate that no writing is needed. We can clean it up after removing - * {@link SupportsSaveMode}. + * {@link Table} that creates this write returns {@link TableCapability#BATCH_WRITE} support in + * its {@link Table#capabilities()}. 
*/ default BatchWrite buildForBatch() { throw new UnsupportedOperationException(getClass().getName() + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala index 228830b56f3aa..b87b3bd4f0761 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrameWriter.scala @@ -29,12 +29,12 @@ import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.plans.logical.{AppendData, InsertIntoTable, LogicalPlan, OverwriteByExpression} import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.command.DDLUtils -import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, LogicalRelation} -import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, DataSourceV2Utils, FileDataSourceV2, WriteToDataSourceV2} +import org.apache.spark.sql.execution.datasources.{CreateTable, DataSource, DataSourceUtils, LogicalRelation} +import org.apache.spark.sql.execution.datasources.v2._ +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.TableCapability._ -import org.apache.spark.sql.sources.v2.writer.SupportsSaveMode import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -55,13 +55,16 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { *
* `SaveMode.Overwrite`: overwrite the existing data.
* `SaveMode.Append`: append the data.
* `SaveMode.Ignore`: ignore the operation (i.e. no-op).
- * `SaveMode.ErrorIfExists`: default option, throw an exception at runtime.
+ * `SaveMode.ErrorIfExists`: throw an exception at runtime.
*
+ *
    + * When writing to data source v1, the default option is `ErrorIfExists`. When writing to data + * source v2, the default option is `Append`. * * @since 1.4.0 */ def mode(saveMode: SaveMode): DataFrameWriter[T] = { - this.mode = saveMode + this.mode = Some(saveMode) this } @@ -77,15 +80,15 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { * @since 1.4.0 */ def mode(saveMode: String): DataFrameWriter[T] = { - this.mode = saveMode.toLowerCase(Locale.ROOT) match { - case "overwrite" => SaveMode.Overwrite - case "append" => SaveMode.Append - case "ignore" => SaveMode.Ignore - case "error" | "errorifexists" | "default" => SaveMode.ErrorIfExists + saveMode.toLowerCase(Locale.ROOT) match { + case "overwrite" => mode(SaveMode.Overwrite) + case "append" => mode(SaveMode.Append) + case "ignore" => mode(SaveMode.Ignore) + case "error" | "errorifexists" => mode(SaveMode.ErrorIfExists) + case "default" => this case _ => throw new IllegalArgumentException(s"Unknown save mode: $saveMode. " + "Accepted save modes are 'overwrite', 'append', 'ignore', 'error', 'errorifexists'.") } - this } /** @@ -269,9 +272,24 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ provider.getTable(dsOptions) match { + // TODO (SPARK-27815): To not break existing tests, here we treat file source as a special + // case, and pass the save mode to file source directly. This hack should be removed. + case table: FileTable => + val write = table.newWriteBuilder(dsOptions).asInstanceOf[FileWriteBuilder] + .mode(modeForDSV1) // should not change default mode for file source. + .withQueryId(UUID.randomUUID().toString) + .withInputDataSchema(df.logicalPlan.schema) + .buildForBatch() + // The returned `Write` can be null, which indicates that we can skip writing. + if (write != null) { + runCommand(df.sparkSession, "save") { + WriteToDataSourceV2(write, df.logicalPlan) + } + } + case table: SupportsWrite if table.supports(BATCH_WRITE) => lazy val relation = DataSourceV2Relation.create(table, dsOptions) - mode match { + modeForDSV2 match { case SaveMode.Append => runCommand(df.sparkSession, "save") { AppendData.byName(relation, df.logicalPlan) @@ -283,25 +301,10 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { OverwriteByExpression.byName(relation, df.logicalPlan, Literal(true)) } - case _ => - table.newWriteBuilder(dsOptions) match { - case writeBuilder: SupportsSaveMode => - val write = writeBuilder.mode(mode) - .withQueryId(UUID.randomUUID().toString) - .withInputDataSchema(df.logicalPlan.schema) - .buildForBatch() - // It can only return null with `SupportsSaveMode`. We can clean it up after - // removing `SupportsSaveMode`. - if (write != null) { - runCommand(df.sparkSession, "save") { - WriteToDataSourceV2(write, df.logicalPlan) - } - } - - case _ => - throw new AnalysisException( - s"data source ${table.name} does not support SaveMode $mode") - } + case other => + throw new AnalysisException(s"TableProvider implementation $source cannot be " + + s"written with $other mode, please use Append or Overwrite " + + "modes instead.") } // Streaming also uses the data source V2 API. 
So it may be that the data source implements @@ -321,7 +324,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { sparkSession = df.sparkSession, className = source, partitionColumns = partitioningColumns.getOrElse(Nil), - options = extraOptions.toMap).planForWriting(mode, df.logicalPlan) + options = extraOptions.toMap).planForWriting(modeForDSV1, df.logicalPlan) } } @@ -370,7 +373,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { table = UnresolvedRelation(tableIdent), partition = Map.empty[String, Option[String]], query = df.logicalPlan, - overwrite = mode == SaveMode.Overwrite, + overwrite = modeForDSV1 == SaveMode.Overwrite, ifPartitionNotExists = false) } } @@ -450,7 +453,7 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { val tableIdentWithDB = tableIdent.copy(database = Some(db)) val tableName = tableIdentWithDB.unquotedString - (tableExists, mode) match { + (tableExists, modeForDSV1) match { case (true, SaveMode.Ignore) => // Do nothing @@ -505,7 +508,8 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { partitionColumnNames = partitioningColumns.getOrElse(Nil), bucketSpec = getBucketSpec) - runCommand(df.sparkSession, "saveAsTable")(CreateTable(tableDesc, mode, Some(df.logicalPlan))) + runCommand(df.sparkSession, "saveAsTable")( + CreateTable(tableDesc, modeForDSV1, Some(df.logicalPlan))) } /** @@ -711,13 +715,17 @@ final class DataFrameWriter[T] private[sql](ds: Dataset[T]) { SQLExecution.withNewExecutionId(session, qe, Some(name))(qe.toRdd) } + private def modeForDSV1 = mode.getOrElse(SaveMode.ErrorIfExists) + + private def modeForDSV2 = mode.getOrElse(SaveMode.Append) + /////////////////////////////////////////////////////////////////////////////////////// // Builder pattern config options /////////////////////////////////////////////////////////////////////////////////////// private var source: String = df.sparkSession.sessionState.conf.defaultDataSourceName - private var mode: SaveMode = SaveMode.ErrorIfExists + private var mode: Option[SaveMode] = None private val extraOptions = new scala.collection.mutable.HashMap[String, String] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 5e20480c3f272..dc89e1a78afde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -21,7 +21,6 @@ import java.util import scala.collection.JavaConverters._ -import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ @@ -44,12 +43,10 @@ private[noop] object NoopTable extends Table with SupportsWrite with SupportsStr override def name(): String = "noop-table" override def schema(): StructType = new StructType() override def capabilities(): util.Set[TableCapability] = Set( - TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA).asJava + TableCapability.BATCH_WRITE, TableCapability.TRUNCATE, TableCapability.ACCEPT_ANY_SCHEMA).asJava } -private[noop] object NoopWriteBuilder extends WriteBuilder - with SupportsSaveMode with SupportsTruncate { - override def mode(mode: SaveMode): WriteBuilder = this +private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsTruncate { override def truncate(): 
WriteBuilder = this override def buildForBatch(): BatchWrite = NoopBatchWrite override def buildForStreaming(): StreamingWrite = NoopStreamingWrite diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala index c19a58034484a..f133dec47b87a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala @@ -33,13 +33,16 @@ import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, DateTimeUtils} import org.apache.spark.sql.execution.datasources.{BasicWriteJobStatsTracker, DataSource, OutputWriterFactory, WriteJobDescription} import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsSaveMode, WriteBuilder} +import org.apache.spark.sql.sources.v2.writer.{BatchWrite, WriteBuilder} import org.apache.spark.sql.types.{DataType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.SerializableConfiguration -abstract class FileWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String]) - extends WriteBuilder with SupportsSaveMode { +abstract class FileWriteBuilder( + options: CaseInsensitiveStringMap, + paths: Seq[String], + formatName: String, + supportsDataType: DataType => Boolean) extends WriteBuilder { private var schema: StructType = _ private var queryId: String = _ private var mode: SaveMode = _ @@ -54,7 +57,7 @@ abstract class FileWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[St this } - override def mode(mode: SaveMode): WriteBuilder = { + def mode(mode: SaveMode): WriteBuilder = { this.mode = mode this } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 1797166bbe0b0..6c771ea988324 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -26,7 +26,6 @@ import org.apache.spark.{SparkEnv, SparkException, TaskContext} import org.apache.spark.executor.CommitDeniedException import org.apache.spark.internal.Logging import org.apache.spark.rdd.RDD -import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog} import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.InternalRow @@ -36,7 +35,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.sources.{AlwaysTrue, Filter} import org.apache.spark.sql.sources.v2.SupportsWrite -import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsSaveMode, SupportsTruncate, WriteBuilder, WriterCommitMessage} +import org.apache.spark.sql.sources.v2.writer.{BatchWrite, DataWriterFactory, SupportsDynamicOverwrite, SupportsOverwrite, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.{LongAccumulator, Utils} @@ -81,16 +80,10 @@ case 
class CreateTableAsSelectExec( Utils.tryWithSafeFinallyAndFailureCallbacks({ catalog.createTable(ident, query.schema, partitioning.toArray, properties.asJava) match { case table: SupportsWrite => - val builder = table.newWriteBuilder(writeOptions) - .withInputDataSchema(query.schema) - .withQueryId(UUID.randomUUID().toString) - val batchWrite = builder match { - case supportsSaveMode: SupportsSaveMode => - supportsSaveMode.mode(SaveMode.Append).buildForBatch() - - case _ => - builder.buildForBatch() - } + val batchWrite = table.newWriteBuilder(writeOptions) + .withInputDataSchema(query.schema) + .withQueryId(UUID.randomUUID().toString) + .buildForBatch() doWrite(batchWrite) @@ -116,13 +109,7 @@ case class AppendDataExec( query: SparkPlan) extends V2TableWriteExec with BatchWriteHelper { override protected def doExecute(): RDD[InternalRow] = { - val batchWrite = newWriteBuilder() match { - case builder: SupportsSaveMode => - builder.mode(SaveMode.Append).buildForBatch() - - case builder => - builder.buildForBatch() - } + val batchWrite = newWriteBuilder().buildForBatch() doWrite(batchWrite) } } @@ -152,9 +139,6 @@ case class OverwriteByExpressionExec( case builder: SupportsTruncate if isTruncate(deleteWhere) => builder.truncate().buildForBatch() - case builder: SupportsSaveMode if isTruncate(deleteWhere) => - builder.mode(SaveMode.Overwrite).buildForBatch() - case builder: SupportsOverwrite => builder.overwrite(deleteWhere).buildForBatch() @@ -185,9 +169,6 @@ case class OverwritePartitionsDynamicExec( case builder: SupportsDynamicOverwrite => builder.overwriteDynamicPartitions().buildForBatch() - case builder: SupportsSaveMode => - builder.mode(SaveMode.Overwrite).buildForBatch() - case _ => throw new SparkException(s"Table does not support dynamic partition overwrite: $table") } @@ -350,8 +331,8 @@ object DataWritingSparkTask extends Logging { } private[v2] case class DataWritingSparkTaskResult( - numRows: Long, - writerCommitMessage: WriterCommitMessage) + numRows: Long, + writerCommitMessage: WriterCommitMessage) /** * Sink progress information collected after commit. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 587cfa9bd6647..379c9c4303cd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -26,7 +26,7 @@ import scala.collection.JavaConverters._ import test.org.apache.spark.sql.sources.v2._ import org.apache.spark.SparkException -import org.apache.spark.sql.{DataFrame, QueryTest, Row} +import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.datasources.v2.{BatchScanExec, DataSourceV2Relation} import org.apache.spark.sql.execution.exchange.{Exchange, ShuffleExchangeExec} @@ -219,14 +219,14 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) - .option("path", path).save() + .option("path", path).mode("append").save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(10).select('id, -'id)) - // test with different save modes + // default save mode is append spark.range(10).select('id as 'i, -'id as 'j).write.format(cls.getName) - .option("path", path).mode("append").save() + .option("path", path).save() checkAnswer( spark.read.format(cls.getName).option("path", path).load(), spark.range(10).union(spark.range(10)).select('id, -'id)) @@ -237,17 +237,17 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { spark.read.format(cls.getName).option("path", path).load(), spark.range(5).select('id, -'id)) - spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) - .option("path", path).mode("ignore").save() - checkAnswer( - spark.read.format(cls.getName).option("path", path).load(), - spark.range(5).select('id, -'id)) + val e = intercept[AnalysisException] { + spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) + .option("path", path).mode("ignore").save() + } + assert(e.message.contains("please use Append or Overwrite modes instead")) - val e = intercept[Exception] { + val e2 = intercept[AnalysisException] { spark.range(5).select('id as 'i, -'id as 'j).write.format(cls.getName) .option("path", path).mode("error").save() } - assert(e.getMessage.contains("data already exists")) + assert(e2.getMessage.contains("please use Append or Overwrite modes instead")) // test transaction val failingUdf = org.apache.spark.sql.functions.udf { @@ -262,10 +262,10 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } // this input data will fail to read middle way. val input = spark.range(10).select(failingUdf('id).as('i)).select('i, -'i as 'j) - val e2 = intercept[SparkException] { + val e3 = intercept[SparkException] { input.write.format(cls.getName).option("path", path).mode("overwrite").save() } - assert(e2.getMessage.contains("Writing job aborted")) + assert(e3.getMessage.contains("Writing job aborted")) // make sure we don't have partial data. 
assert(spark.read.format(cls.getName).option("path", path).load().collect().isEmpty) } @@ -375,21 +375,16 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { } } - test("SPARK-25700: do not read schema when writing in other modes except append and overwrite") { - withTempPath { file => - val cls = classOf[SimpleWriteOnlyDataSource] - val path = file.getCanonicalPath - val df = spark.range(5).select('id as 'i, -'id as 'j) - // non-append mode should not throw exception, as they don't access schema. - df.write.format(cls.getName).option("path", path).mode("error").save() - df.write.format(cls.getName).option("path", path).mode("ignore").save() - // append and overwrite modes will access the schema and should throw exception. - intercept[SchemaReadAttemptException] { - df.write.format(cls.getName).option("path", path).mode("append").save() - } - intercept[SchemaReadAttemptException] { - df.write.format(cls.getName).option("path", path).mode("overwrite").save() - } + test("SPARK-27411: DataSourceV2Strategy should not eliminate subquery") { + withTempView("t1") { + val t2 = spark.read.format(classOf[SimpleDataSourceV2].getName).load() + Seq(2, 3).toDF("a").createTempView("t1") + val df = t2.where("i < (select max(a) from t1)").select('i) + val subqueries = df.queryExecution.executedPlan.collect { + case p => p.subqueries + }.flatten + assert(subqueries.length == 1) + checkAnswer(df, (0 until 3).map(i => Row(i))) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala index e019dbfe3f512..e84c082128e1c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/FileDataSourceV2FallBackSuite.scala @@ -50,7 +50,7 @@ class DummyReadOnlyFileTable extends Table with SupportsRead { } override def capabilities(): java.util.Set[TableCapability] = - Set(TableCapability.BATCH_READ).asJava + Set(TableCapability.BATCH_READ, TableCapability.ACCEPT_ANY_SCHEMA).asJava } class DummyWriteOnlyFileDataSourceV2 extends FileDataSourceV2 { @@ -73,7 +73,7 @@ class DummyWriteOnlyFileTable extends Table with SupportsWrite { throw new AnalysisException("Dummy file writer") override def capabilities(): java.util.Set[TableCapability] = - Set(TableCapability.BATCH_WRITE).asJava + Set(TableCapability.BATCH_WRITE, TableCapability.ACCEPT_ANY_SCHEMA).asJava } class FileDataSourceV2FallBackSuite extends QueryTest with SharedSQLContext { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala index edebb0b62b29c..c9d2f1eef24bb 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/SimpleWritableDataSource.scala @@ -26,7 +26,6 @@ import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.SparkContext -import org.apache.spark.sql.SaveMode import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.sources.v2.reader._ @@ -70,38 +69,26 @@ class SimpleWritableDataSource extends TableProvider with SessionConfigSupport { override def readSchema(): StructType = tableSchema } - class MyWriteBuilder(path: 
String) extends WriteBuilder with SupportsSaveMode { + class MyWriteBuilder(path: String) extends WriteBuilder with SupportsTruncate { private var queryId: String = _ - private var mode: SaveMode = _ + private var needTruncate = false override def withQueryId(queryId: String): WriteBuilder = { this.queryId = queryId this } - override def mode(mode: SaveMode): WriteBuilder = { - this.mode = mode + override def truncate(): WriteBuilder = { + this.needTruncate = true this } override def buildForBatch(): BatchWrite = { - assert(mode != null) - val hadoopPath = new Path(path) val hadoopConf = SparkContext.getActive.get.hadoopConfiguration val fs = hadoopPath.getFileSystem(hadoopConf) - if (mode == SaveMode.ErrorIfExists) { - if (fs.exists(hadoopPath)) { - throw new RuntimeException("data already exists.") - } - } - if (mode == SaveMode.Ignore) { - if (fs.exists(hadoopPath)) { - return null - } - } - if (mode == SaveMode.Overwrite) { + if (needTruncate) { fs.delete(hadoopPath, true) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala index e45ab19aadbfa..a388de1970f14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/DataFrameReaderWriterSuite.scala @@ -38,10 +38,15 @@ import org.apache.spark.internal.io.HadoopMapReduceCommitProtocol import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, LogicalPlan, OverwriteByExpression} +import org.apache.spark.sql.execution.QueryExecution +import org.apache.spark.sql.execution.datasources.DataSourceUtils +import org.apache.spark.sql.execution.datasources.noop.NoopDataSource import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.QueryExecutionListener import org.apache.spark.util.Utils @@ -220,15 +225,75 @@ class DataFrameReaderWriterSuite extends QueryTest with SharedSQLContext with Be } test("save mode") { - val df = spark.read + spark.range(10).write .format("org.apache.spark.sql.test") - .load() + .mode(SaveMode.ErrorIfExists) + .save() + assert(LastOptions.saveMode === SaveMode.ErrorIfExists) - df.write + spark.range(10).write + .format("org.apache.spark.sql.test") + .mode(SaveMode.Append) + .save() + assert(LastOptions.saveMode === SaveMode.Append) + + // By default the save mode is `ErrorIfExists` for data source v1. 
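// Editor's note (not part of the patch): a condensed sketch of the listener pattern used by
// the new "save mode for data source v2" test that follows, which captures the analyzed plan
// of the last successful write so the test can check whether save() produced an AppendData
// or an OverwriteByExpression node. Assumes a running SparkSession named `spark`; only APIs
// already used in the test below are relied on.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.util.QueryExecutionListener

def capturePlanOf(spark: SparkSession)(write: => Unit): LogicalPlan = {
  var plan: LogicalPlan = null
  val listener = new QueryExecutionListener {
    override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = {
      plan = qe.analyzed  // remember the analyzed plan of the finished write
    }
    override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {}
  }
  spark.listenerManager.register(listener)
  try {
    write  // e.g. df.write.format(...).mode("append").save()
    // listeners are invoked asynchronously; listenerBus is private[spark], so like the suite
    // here this assumes the helper lives in an org.apache.spark.* test package
    spark.sparkContext.listenerBus.waitUntilEmpty(1000)
    plan
  } finally {
    spark.listenerManager.unregister(listener)
  }
}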
+ spark.range(10).write .format("org.apache.spark.sql.test") - .mode(SaveMode.ErrorIfExists) .save() assert(LastOptions.saveMode === SaveMode.ErrorIfExists) + + spark.range(10).write + .format("org.apache.spark.sql.test") + .mode("default") + .save() + assert(LastOptions.saveMode === SaveMode.ErrorIfExists) + } + + test("save mode for data source v2") { + var plan: LogicalPlan = null + val listener = new QueryExecutionListener { + override def onSuccess(funcName: String, qe: QueryExecution, durationNs: Long): Unit = { + plan = qe.analyzed + + } + override def onFailure(funcName: String, qe: QueryExecution, exception: Exception): Unit = {} + } + + spark.listenerManager.register(listener) + try { + // append mode creates `AppendData` + spark.range(10).write + .format(classOf[NoopDataSource].getName) + .mode(SaveMode.Append) + .save() + sparkContext.listenerBus.waitUntilEmpty(1000) + assert(plan.isInstanceOf[AppendData]) + + // overwrite mode creates `OverwriteByExpression` + spark.range(10).write + .format(classOf[NoopDataSource].getName) + .mode(SaveMode.Overwrite) + .save() + sparkContext.listenerBus.waitUntilEmpty(1000) + assert(plan.isInstanceOf[OverwriteByExpression]) + + // By default the save mode is `ErrorIfExists` for data source v2. + spark.range(10).write + .format(classOf[NoopDataSource].getName) + .save() + sparkContext.listenerBus.waitUntilEmpty(1000) + assert(plan.isInstanceOf[AppendData]) + + spark.range(10).write + .format(classOf[NoopDataSource].getName) + .mode("default") + .save() + sparkContext.listenerBus.waitUntilEmpty(1000) + assert(plan.isInstanceOf[AppendData]) + } finally { + spark.listenerManager.unregister(listener) + } } test("test path option in load") { From c0ffa903c2150197f8fa1b86d8d171e84c438ba2 Mon Sep 17 00:00:00 2001 From: mcheah Date: Fri, 24 May 2019 14:46:02 -0700 Subject: [PATCH 40/70] Fix compilation issues --- .../sql/execution/datasources/v2/FileWriteBuilder.scala | 7 ++++--- .../sql/execution/datasources/v2/orc/OrcWriteBuilder.scala | 6 +++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala index f133dec47b87a..5375d965d1eff 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/FileWriteBuilder.scala @@ -41,7 +41,7 @@ import org.apache.spark.util.SerializableConfiguration abstract class FileWriteBuilder( options: CaseInsensitiveStringMap, paths: Seq[String], - formatName: String, + _formatName: String, supportsDataType: DataType => Boolean) extends WriteBuilder { private var schema: StructType = _ private var queryId: String = _ @@ -133,9 +133,10 @@ abstract class FileWriteBuilder( assert(paths.length == 1) DataSource.validateSchema(schema) schema.foreach { field => - if (!supportsDataType(field.dataType)) { + if (!supportsDataType.apply(field.dataType)) { throw new AnalysisException( - s"$formatName data source does not support ${field.dataType.catalogString} data type.") + s"$formatName data source does not support ${field.dataType.catalogString}" + + s" data type.") } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala index 829ab5fbe1768..b1f8b8916a390 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcWriteBuilder.scala @@ -29,7 +29,11 @@ import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap class OrcWriteBuilder(options: CaseInsensitiveStringMap, paths: Seq[String]) - extends FileWriteBuilder(options, paths) { + extends FileWriteBuilder( + options, + paths, + "orc", + supportsDataType = OrcDataSourceV2.supportsDataType) { override def prepareWrite( sqlConf: SQLConf, From c96838822903684eab67898e2bdc867e4315b9a8 Mon Sep 17 00:00:00 2001 From: mcheah Date: Fri, 24 May 2019 14:55:52 -0700 Subject: [PATCH 41/70] Fix scalastyle --- .../spark/sql/execution/datasources/noop/NoopDataSource.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index dc89e1a78afde..1da41f2baefcb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -43,7 +43,9 @@ private[noop] object NoopTable extends Table with SupportsWrite with SupportsStr override def name(): String = "noop-table" override def schema(): StructType = new StructType() override def capabilities(): util.Set[TableCapability] = Set( - TableCapability.BATCH_WRITE, TableCapability.TRUNCATE, TableCapability.ACCEPT_ANY_SCHEMA).asJava + TableCapability.BATCH_WRITE, + TableCapability.TRUNCATE, + TableCapability.ACCEPT_ANY_SCHEMA).asJava } private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsTruncate { From d7e39433d4c117c9842a08761869fd9d0281486e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 5 Jun 2019 09:55:55 -0700 Subject: [PATCH 42/70] [SPARK-27521][SQL] Move data source v2 to catalyst module Currently we are in a strange status that, some data source v2 interfaces(catalog related) are in sql/catalyst, some data source v2 interfaces(Table, ScanBuilder, DataReader, etc.) are in sql/core. I don't see a reason to keep data source v2 API in 2 modules. If we should pick one module, I think sql/catalyst is the one to go. Catalyst module already has some user-facing stuff like DataType, Row, etc. And we have to update `Analyzer` and `SessionCatalog` to support the new catalog plugin, which needs to be in the catalyst module. This PR can solve the problem we have in https://github.com/apache/spark/pull/24246 existing tests Closes #24416 from cloud-fan/move. 
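The `WriteBuilder`, `SupportsTruncate`, and `BatchWrite` interfaces touched throughout the patches above are part of the API surface this commit relocates into sql/catalyst. As a hedged sketch of how a caller is expected to drive that builder API, using only the methods visible in these patches (the helper below is illustrative, not Spark's actual planner code):

import org.apache.spark.sql.SaveMode
import org.apache.spark.sql.sources.v2.writer.{BatchWrite, SupportsTruncate, WriteBuilder}

// Illustrative helper: configure a WriteBuilder for a batch write. Overwrite is expressed as
// "truncate, then append", which is why SupportsSaveMode could be replaced by SupportsTruncate
// in SimpleWritableDataSource earlier in this series.
def buildBatchWrite(builder: WriteBuilder, mode: SaveMode, queryId: String): BatchWrite = {
  val withId = builder.withQueryId(queryId)
  val configured = mode match {
    case SaveMode.Append => withId
    case SaveMode.Overwrite =>
      withId match {
        case t: SupportsTruncate => t.truncate()
        case _ => throw new IllegalArgumentException("table does not support truncate")
      }
    case other =>
      throw new IllegalArgumentException(
        s"$other is not supported, please use Append or Overwrite modes instead")
  }
  configured.buildForBatch()
}

Expressing Overwrite as truncate-plus-append is what lets one builder serve both the AppendData and OverwriteByExpression plans exercised by the tests above.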
Authored-by: Wenchen Fan Signed-off-by: gatorsmile --- project/MimaExcludes.scala | 121 +++++++++++ sql/catalyst/pom.xml | 4 + .../sql/sources/v2/SessionConfigSupport.java | 0 .../spark/sql/sources/v2/SupportsRead.java | 0 .../spark/sql/sources/v2/SupportsWrite.java | 0 .../spark/sql/sources/v2/TableProvider.java | 9 +- .../spark/sql/sources/v2/reader/Batch.java | 0 .../sql/sources/v2/reader/InputPartition.java | 0 .../sources/v2/reader/PartitionReader.java | 0 .../v2/reader/PartitionReaderFactory.java | 0 .../spark/sql/sources/v2/reader/Scan.java | 0 .../sql/sources/v2/reader/ScanBuilder.java | 0 .../sql/sources/v2/reader/Statistics.java | 0 .../v2/reader/SupportsPushDownFilters.java | 0 .../SupportsPushDownRequiredColumns.java | 0 .../v2/reader/SupportsReportPartitioning.java | 0 .../v2/reader/SupportsReportStatistics.java | 0 .../partitioning/ClusteredDistribution.java | 0 .../v2/reader/partitioning/Distribution.java | 0 .../v2/reader/partitioning/Partitioning.java | 0 .../streaming/ContinuousPartitionReader.java | 0 .../ContinuousPartitionReaderFactory.java | 0 .../v2/reader/streaming/ContinuousStream.java | 0 .../v2/reader/streaming/MicroBatchStream.java | 0 .../sources/v2/reader/streaming/Offset.java | 0 .../v2/reader/streaming/PartitionOffset.java | 0 .../v2/reader/streaming/SparkDataStream.java | 0 .../sql/sources/v2/writer/BatchWrite.java | 0 .../sql/sources/v2/writer/DataWriter.java | 0 .../sources/v2/writer/DataWriterFactory.java | 0 .../v2/writer/SupportsDynamicOverwrite.java | 0 .../sources/v2/writer/SupportsOverwrite.java | 0 .../sources/v2/writer/SupportsTruncate.java | 0 .../sql/sources/v2/writer/WriteBuilder.java | 0 .../v2/writer/WriterCommitMessage.java | 0 .../streaming/StreamingDataWriterFactory.java | 0 .../v2/writer/streaming/StreamingWrite.java | 0 .../sql/vectorized/ArrowColumnVector.java | 2 +- .../spark/sql/vectorized/ColumnVector.java | 0 .../spark/sql/vectorized/ColumnarArray.java | 0 .../spark/sql/vectorized/ColumnarBatch.java | 0 .../spark/sql/vectorized/ColumnarMap.java | 0 .../spark/sql/vectorized/ColumnarRow.java | 0 .../apache/spark/sql/sources/filters.scala | 0 .../apache/spark/sql/util}/ArrowUtils.scala | 2 +- .../spark/sql/util}/ArrowUtilsSuite.scala | 4 +- sql/core/pom.xml | 4 - .../sql/execution/arrow/ArrowConverters.scala | 1 + .../sql/execution/arrow/ArrowWriter.scala | 1 + .../python/AggregateInPandasExec.scala | 2 +- .../python/ArrowEvalPythonExec.scala | 2 +- .../execution/python/ArrowPythonRunner.scala | 3 +- .../python/FlatMapGroupsInPandasExec.scala | 4 +- .../execution/python/WindowInPandasExec.scala | 2 +- .../spark/sql/execution/r/ArrowRRunner.scala | 201 ++++++++++++++++++ .../arrow/ArrowConvertersSuite.scala | 1 + .../sources/RateStreamProviderSuite.scala | 2 +- .../sources/TextSocketStreamSuite.scala | 2 +- .../vectorized/ArrowColumnVectorSuite.scala | 2 +- .../vectorized/ColumnarBatchSuite.scala | 4 +- 60 files changed, 347 insertions(+), 26 deletions(-) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java (89%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java (100%) rename sql/{core => 
catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousStream.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchStream.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWrite.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java (100%) rename sql/{core => 
catalyst}/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java (99%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java (100%) rename sql/{core => catalyst}/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java (100%) rename sql/{core => catalyst}/src/main/scala/org/apache/spark/sql/sources/filters.scala (100%) rename sql/{core/src/main/scala/org/apache/spark/sql/execution/arrow => catalyst/src/main/scala/org/apache/spark/sql/util}/ArrowUtils.scala (99%) rename sql/{core/src/test/scala/org/apache/spark/sql/execution/arrow => catalyst/src/test/scala/org/apache/spark/sql/util}/ArrowUtilsSuite.scala (96%) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 76744af2327c5..c73c35dc0c905 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -165,7 +165,128 @@ object MimaExcludes { case ReversedMissingMethodProblem(meth) => !meth.owner.fullName.startsWith("org.apache.spark.sql.sources.v2") case _ => true +<<<<<<< HEAD } +||||||| parent of 8b6232b119... [SPARK-27521][SQL] Move data source v2 to catalyst module + }, + + // [SPARK-26216][SQL] Do not use case class as public API (UserDefinedFunction) + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"), + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.inputTypes"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullableTypes_="), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.dataType"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.f"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.this"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNonNullable"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNonNullable"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullable"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullable"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNondeterministic"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNondeterministic"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.deterministic"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.deterministic"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), + 
ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.withName"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.withName"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productElement"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productArity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$2"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.canEqual"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$1"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productIterator"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productPrefix"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$3"), + + // [SPARK-11215][ML] Add multiple columns support to StringIndexer + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexer.validateAndTransformSchema"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.validateAndTransformSchema"), + + // [SPARK-26616][MLlib] Expose document frequency in IDFModel + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.IDFModel.this"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.feature.IDF#DocumentFrequencyAggregator.idf") +======= + }, + + // [SPARK-27521][SQL] Move data source v2 to catalyst module + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnarBatch"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ArrowColumnVector"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnarRow"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnarArray"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnarMap"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.vectorized.ColumnVector"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.GreaterThanOrEqual"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringEndsWith"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LessThanOrEqual$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.In$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Not"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.IsNotNull"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LessThan"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LessThanOrEqual"), + 
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.EqualNullSafe$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.GreaterThan$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.In"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.And"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringStartsWith$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.EqualNullSafe"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringEndsWith$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.GreaterThanOrEqual$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Not$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.IsNull$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.LessThan$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.IsNotNull$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Or"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.EqualTo$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.GreaterThan"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringContains"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Filter"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.IsNull"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.EqualTo"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.And$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.Or$"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringStartsWith"), + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.StringContains$"), + + // [SPARK-26216][SQL] Do not use case class as public API (UserDefinedFunction) + ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"), + ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.inputTypes"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullableTypes_="), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.dataType"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.f"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.this"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNonNullable"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNonNullable"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullable"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullable"), + 
ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNondeterministic"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNondeterministic"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.deterministic"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.deterministic"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), + ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.withName"), + ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.withName"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productElement"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productArity"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$2"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.canEqual"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$1"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productIterator"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productPrefix"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$3"), + + // [SPARK-11215][ML] Add multiple columns support to StringIndexer + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexer.validateAndTransformSchema"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.validateAndTransformSchema"), + + // [SPARK-26616][MLlib] Expose document frequency in IDFModel + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.IDFModel.this"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.feature.IDF#DocumentFrequencyAggregator.idf") +>>>>>>> 8b6232b119... 
[SPARK-27521][SQL] Move data source v2 to catalyst module ) // Exclude rules for 2.4.x diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 16ecebf159c1f..323032fbfd1f9 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -109,6 +109,10 @@ 2.7.3 jar + + org.apache.arrow + arrow-vector + target/scala-${scala.binary.version}/classes diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SessionConfigSupport.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SupportsRead.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/SupportsWrite.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java similarity index 89% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java index 0e2eb9c3cabb7..1d37ff042bd33 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableProvider.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.sources.v2; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.DataSourceRegister; import org.apache.spark.sql.types.StructType; import org.apache.spark.sql.util.CaseInsensitiveStringMap; @@ -56,13 +55,7 @@ public interface TableProvider { * @throws UnsupportedOperationException */ default Table getTable(CaseInsensitiveStringMap options, StructType schema) { - String name; - if (this instanceof DataSourceRegister) { - name = ((DataSourceRegister) this).shortName(); - } else { - name = this.getClass().getName(); - } throw new UnsupportedOperationException( - name + " source does not support user-specified schema"); + this.getClass().getSimpleName() + " source does not support user-specified schema"); } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Batch.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java rename to 
sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/InputPartition.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReader.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/PartitionReaderFactory.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/ScanBuilder.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Statistics.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownFilters.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsPushDownRequiredColumns.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java 
b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportStatistics.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/ClusteredDistribution.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Distribution.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/partitioning/Partitioning.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReader.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousPartitionReaderFactory.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousStream.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousStream.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/ContinuousStream.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchStream.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchStream.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/MicroBatchStream.java 
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/PartitionOffset.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWrite.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWrite.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/BatchWrite.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriter.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/DataWriterFactory.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsDynamicOverwrite.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsOverwrite.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java similarity index 100% rename from 
sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/SupportsTruncate.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/WriteBuilder.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/WriterCommitMessage.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingDataWriterFactory.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/writer/streaming/StreamingWrite.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java similarity index 99% rename from sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index 906e9bc26ef53..07d17ee14ce23 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -23,7 +23,7 @@ import org.apache.arrow.vector.holders.NullableVarCharHolder; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.execution.arrow.ArrowUtils; +import org.apache.spark.sql.util.ArrowUtils; import org.apache.spark.sql.types.*; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarMap.java diff --git a/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java similarity index 100% rename from sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java diff --git a/sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/sources/filters.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala similarity index 99% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 7de6256aef084..62546a322d3c9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.arrow +package org.apache.spark.sql.util import scala.collection.JavaConverters._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala similarity index 96% rename from sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala index d801f62b62323..4439a7bb3ae87 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/ArrowUtilsSuite.scala @@ -15,9 +15,9 @@ * limitations under the License. 
*/ -package org.apache.spark.sql.execution.arrow +package org.apache.spark.sql.util -import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.arrow.vector.types.pojo.ArrowType import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.util.DateTimeUtils diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 95e98c5444721..6f0db3632d7dd 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -112,10 +112,6 @@ com.fasterxml.jackson.core jackson-databind - - org.apache.arrow - arrow-vector - org.apache.xbean xbean-asm7-shaded diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index 2bf6a58b55658..4b692aaeb1e63 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -35,6 +35,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.{ByteBufferOutputStream, Utils} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala index 8dd484af6e908..6147d6fefd52a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowWriter.scala @@ -25,6 +25,7 @@ import org.apache.arrow.vector.complex._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.SpecializedGetters import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils object ArrowWriter { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala index 2ab7240556aaa..0c78cca086ed3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/AggregateInPandasExec.scala @@ -28,8 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.{DataType, StructField, StructType} +import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala index a5203daea9cd0..d1105f0382f6f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowEvalPythonExec.scala @@ -25,8 +25,8 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import 
org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, UnaryNode} import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils /** * Grouped a iterator into batches. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 04623b1ab3c2f..3710218b2af5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -29,8 +29,9 @@ import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} import org.apache.spark._ import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter} +import org.apache.spark.sql.execution.arrow.ArrowWriter import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index e9cff1a5a2007..18b074b807807 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -27,8 +27,10 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} -import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} /** * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]] diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala index 1ce1215bfdd62..01ce07b133ffd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/WindowInPandasExec.scala @@ -29,9 +29,9 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{ExternalAppendOnlyUnsafeRowArray, SparkPlan} -import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.execution.window._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala 
new file mode 100644 index 0000000000000..0fe2b628fa38b --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala @@ -0,0 +1,201 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.r + +import java.io._ +import java.nio.channels.Channels + +import scala.collection.JavaConverters._ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} +import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel + +import org.apache.spark.{SparkException, TaskContext} +import org.apache.spark.api.r._ +import org.apache.spark.api.r.SpecialLengths +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.util.Utils + + +/** + * Similar to `ArrowPythonRunner`, but exchange data with R worker via Arrow stream. + */ +class ArrowRRunner( + func: Array[Byte], + packageNames: Array[Byte], + broadcastVars: Array[Broadcast[Object]], + schema: StructType, + timeZoneId: String, + mode: Int) + extends BaseRRunner[Iterator[InternalRow], ColumnarBatch]( + func, + "arrow", + "arrow", + packageNames, + broadcastVars, + numPartitions = -1, + isDataFrame = true, + schema.fieldNames, + mode) { + + protected def bufferedWrite( + dataOut: DataOutputStream)(writeFunc: ByteArrayOutputStream => Unit): Unit = { + val out = new ByteArrayOutputStream() + writeFunc(out) + + // Currently, there looks no way to read batch by batch by socket connection in R side, + // See ARROW-4512. Therefore, it writes the whole Arrow streaming-formatted binary at + // once for now. + val data = out.toByteArray + dataOut.writeInt(data.length) + dataOut.write(data) + } + + protected override def newWriterThread( + output: OutputStream, + inputIterator: Iterator[Iterator[InternalRow]], + partitionIndex: Int): WriterThread = { + new WriterThread(output, inputIterator, partitionIndex) { + + /** + * Writes input data to the stream connected to the R worker. 
+ */ + override protected def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + if (inputIterator.hasNext) { + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + "stdout writer for R", 0, Long.MaxValue) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + + bufferedWrite(dataOut) { out => + Utils.tryWithSafeFinally { + val arrowWriter = ArrowWriter.create(root) + val writer = new ArrowStreamWriter(root, null, Channels.newChannel(out)) + writer.start() + + while (inputIterator.hasNext) { + val nextBatch: Iterator[InternalRow] = inputIterator.next() + + while (nextBatch.hasNext) { + arrowWriter.write(nextBatch.next()) + } + + arrowWriter.finish() + writer.writeBatch() + arrowWriter.reset() + } + writer.end() + } { + // Don't close root and allocator in TaskCompletionListener to prevent + // a race condition. See `ArrowPythonRunner`. + root.close() + allocator.close() + } + } + } + } + } + } + + protected override def newReaderIterator( + dataStream: DataInputStream, errThread: BufferedStreamThread): ReaderIterator = { + new ReaderIterator(dataStream, errThread) { + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + "stdin reader for R", 0, Long.MaxValue) + + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var vectors: Array[ColumnVector] = _ + + TaskContext.get().addTaskCompletionListener[Unit] { _ => + if (reader != null) { + reader.close(false) + } + allocator.close() + } + + private var batchLoaded = true + + protected override def read(): ColumnarBatch = try { + if (reader != null && batchLoaded) { + batchLoaded = reader.loadNextBatch() + if (batchLoaded) { + val batch = new ColumnarBatch(vectors) + batch.setNumRows(root.getRowCount) + batch + } else { + reader.close(false) + allocator.close() + // Should read timing data after this. + read() + } + } else { + dataStream.readInt() match { + case SpecialLengths.TIMING_DATA => + // Timing data from R worker + val boot = dataStream.readDouble - bootTime + val init = dataStream.readDouble + val broadcast = dataStream.readDouble + val input = dataStream.readDouble + val compute = dataStream.readDouble + val output = dataStream.readDouble + logInfo( + ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + + "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + + "total = %.3f s").format( + boot, + init, + broadcast, + input, + compute, + output, + boot + init + broadcast + input + compute + output)) + read() + case length if length > 0 => + // Likewise, there looks no way to send each batch in streaming format via socket + // connection. See ARROW-4512. + // So, it reads the whole Arrow streaming-formatted binary at once for now. 
+ val buffer = new Array[Byte](length) + dataStream.readFully(buffer) + val in = new ByteArrayReadableSeekableByteChannel(buffer) + reader = new ArrowStreamReader(in, allocator) + root = reader.getVectorSchemaRoot + vectors = root.getFieldVectors.asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + read() + case length if length == 0 => + // End of stream + eos = true + null + } + } + } catch { + case eof: EOFException => + throw new SparkException( + "R worker exited unexpectedly (crashed)\n " + errThread.getLines(), eof) + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala index c36872a6a5289..86874b9817c20 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowConvertersSuite.scala @@ -36,6 +36,7 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{BinaryType, Decimal, IntegerType, StructField, StructType} +import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index c04f6e3f255cb..b024f957020a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -305,7 +305,7 @@ class RateStreamProviderSuite extends StreamTest { .load() } assert(exception.getMessage.contains( - "rate source does not support user-specified schema")) + "RateStreamProvider source does not support user-specified schema")) } test("continuous data") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 6a7c54176c347..956339355de48 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -204,7 +204,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before provider.getTable(new CaseInsensitiveStringMap(params.asJava), userSpecifiedSchema) } assert(exception.getMessage.contains( - "socket source does not support user-specified schema")) + "TextSocketSourceProvider source does not support user-specified schema")) } test("input row metrics") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 4592a1663faed..60f1b32a41f05 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -21,8 +21,8 @@ import org.apache.arrow.vector._ import org.apache.arrow.vector.complex._ import org.apache.spark.SparkFunSuite -import 
org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index e8062dbb91e35..4dd65385d548b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -31,9 +31,9 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.memory.MemoryMode import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.CalendarInterval From e0edb6c9d908facce63f6b627d2b47f2f53b30a3 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 6 Jun 2019 13:41:41 -0700 Subject: [PATCH 43/70] Fix merge conflicts --- project/MimaExcludes.scala | 44 ---- .../spark/sql/execution/r/ArrowRRunner.scala | 201 ------------------ 2 files changed, 245 deletions(-) delete mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index c73c35dc0c905..85127f1fb5c9a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -165,49 +165,6 @@ object MimaExcludes { case ReversedMissingMethodProblem(meth) => !meth.owner.fullName.startsWith("org.apache.spark.sql.sources.v2") case _ => true -<<<<<<< HEAD - } -||||||| parent of 8b6232b119... 
[SPARK-27521][SQL] Move data source v2 to catalyst module - }, - - // [SPARK-26216][SQL] Do not use case class as public API (UserDefinedFunction) - ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction$"), - ProblemFilters.exclude[AbstractClassProblem]("org.apache.spark.sql.expressions.UserDefinedFunction"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.inputTypes"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullableTypes_="), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.dataType"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.f"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.this"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNonNullable"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNonNullable"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullable"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.nullable"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNondeterministic"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.asNondeterministic"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.deterministic"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.deterministic"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.apply"), - ProblemFilters.exclude[DirectAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.withName"), - ProblemFilters.exclude[ReversedAbstractMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.withName"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productElement"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productArity"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$2"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.canEqual"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$1"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productIterator"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.productPrefix"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.expressions.UserDefinedFunction.copy$default$3"), - - // 
[SPARK-11215][ML] Add multiple columns support to StringIndexer - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexer.validateAndTransformSchema"), - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.StringIndexerModel.validateAndTransformSchema"), - - // [SPARK-26616][MLlib] Expose document frequency in IDFModel - ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.IDFModel.this"), - ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.feature.IDF#DocumentFrequencyAggregator.idf") -======= }, // [SPARK-27521][SQL] Move data source v2 to catalyst module @@ -286,7 +243,6 @@ object MimaExcludes { // [SPARK-26616][MLlib] Expose document frequency in IDFModel ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.mllib.feature.IDFModel.this"), ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.mllib.feature.IDF#DocumentFrequencyAggregator.idf") ->>>>>>> 8b6232b119... [SPARK-27521][SQL] Move data source v2 to catalyst module ) // Exclude rules for 2.4.x diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala deleted file mode 100644 index 0fe2b628fa38b..0000000000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/r/ArrowRRunner.scala +++ /dev/null @@ -1,201 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.r - -import java.io._ -import java.nio.channels.Channels - -import scala.collection.JavaConverters._ - -import org.apache.arrow.vector.VectorSchemaRoot -import org.apache.arrow.vector.ipc.{ArrowStreamReader, ArrowStreamWriter} -import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel - -import org.apache.spark.{SparkException, TaskContext} -import org.apache.spark.api.r._ -import org.apache.spark.api.r.SpecialLengths -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.arrow.ArrowWriter -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} -import org.apache.spark.util.Utils - - -/** - * Similar to `ArrowPythonRunner`, but exchange data with R worker via Arrow stream. 
- */ -class ArrowRRunner( - func: Array[Byte], - packageNames: Array[Byte], - broadcastVars: Array[Broadcast[Object]], - schema: StructType, - timeZoneId: String, - mode: Int) - extends BaseRRunner[Iterator[InternalRow], ColumnarBatch]( - func, - "arrow", - "arrow", - packageNames, - broadcastVars, - numPartitions = -1, - isDataFrame = true, - schema.fieldNames, - mode) { - - protected def bufferedWrite( - dataOut: DataOutputStream)(writeFunc: ByteArrayOutputStream => Unit): Unit = { - val out = new ByteArrayOutputStream() - writeFunc(out) - - // Currently, there looks no way to read batch by batch by socket connection in R side, - // See ARROW-4512. Therefore, it writes the whole Arrow streaming-formatted binary at - // once for now. - val data = out.toByteArray - dataOut.writeInt(data.length) - dataOut.write(data) - } - - protected override def newWriterThread( - output: OutputStream, - inputIterator: Iterator[Iterator[InternalRow]], - partitionIndex: Int): WriterThread = { - new WriterThread(output, inputIterator, partitionIndex) { - - /** - * Writes input data to the stream connected to the R worker. - */ - override protected def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - if (inputIterator.hasNext) { - val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId) - val allocator = ArrowUtils.rootAllocator.newChildAllocator( - "stdout writer for R", 0, Long.MaxValue) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - - bufferedWrite(dataOut) { out => - Utils.tryWithSafeFinally { - val arrowWriter = ArrowWriter.create(root) - val writer = new ArrowStreamWriter(root, null, Channels.newChannel(out)) - writer.start() - - while (inputIterator.hasNext) { - val nextBatch: Iterator[InternalRow] = inputIterator.next() - - while (nextBatch.hasNext) { - arrowWriter.write(nextBatch.next()) - } - - arrowWriter.finish() - writer.writeBatch() - arrowWriter.reset() - } - writer.end() - } { - // Don't close root and allocator in TaskCompletionListener to prevent - // a race condition. See `ArrowPythonRunner`. - root.close() - allocator.close() - } - } - } - } - } - } - - protected override def newReaderIterator( - dataStream: DataInputStream, errThread: BufferedStreamThread): ReaderIterator = { - new ReaderIterator(dataStream, errThread) { - private val allocator = ArrowUtils.rootAllocator.newChildAllocator( - "stdin reader for R", 0, Long.MaxValue) - - private var reader: ArrowStreamReader = _ - private var root: VectorSchemaRoot = _ - private var vectors: Array[ColumnVector] = _ - - TaskContext.get().addTaskCompletionListener[Unit] { _ => - if (reader != null) { - reader.close(false) - } - allocator.close() - } - - private var batchLoaded = true - - protected override def read(): ColumnarBatch = try { - if (reader != null && batchLoaded) { - batchLoaded = reader.loadNextBatch() - if (batchLoaded) { - val batch = new ColumnarBatch(vectors) - batch.setNumRows(root.getRowCount) - batch - } else { - reader.close(false) - allocator.close() - // Should read timing data after this. 
- read() - } - } else { - dataStream.readInt() match { - case SpecialLengths.TIMING_DATA => - // Timing data from R worker - val boot = dataStream.readDouble - bootTime - val init = dataStream.readDouble - val broadcast = dataStream.readDouble - val input = dataStream.readDouble - val compute = dataStream.readDouble - val output = dataStream.readDouble - logInfo( - ("Times: boot = %.3f s, init = %.3f s, broadcast = %.3f s, " + - "read-input = %.3f s, compute = %.3f s, write-output = %.3f s, " + - "total = %.3f s").format( - boot, - init, - broadcast, - input, - compute, - output, - boot + init + broadcast + input + compute + output)) - read() - case length if length > 0 => - // Likewise, there looks no way to send each batch in streaming format via socket - // connection. See ARROW-4512. - // So, it reads the whole Arrow streaming-formatted binary at once for now. - val buffer = new Array[Byte](length) - dataStream.readFully(buffer) - val in = new ByteArrayReadableSeekableByteChannel(buffer) - reader = new ArrowStreamReader(in, allocator) - root = reader.getVectorSchemaRoot - vectors = root.getFieldVectors.asScala.map { vector => - new ArrowColumnVector(vector) - }.toArray[ColumnVector] - read() - case length if length == 0 => - // End of stream - eos = true - null - } - } - } catch { - case eof: EOFException => - throw new SparkException( - "R worker exited unexpectedly (crashed)\n " + errThread.getLines(), eof) - } - } - } -} From c7c5d84be451f24cd3b784cb407bd9d456f7c670 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Fri, 24 May 2019 11:13:22 +0800 Subject: [PATCH 44/70] [SPARK-27732][SQL] Add v2 CreateTable implementation. ## What changes were proposed in this pull request? This adds a v2 implementation of create table: * `CreateV2Table` is the logical plan, named using v2 to avoid conflicting with the existing plan * `CreateTableExec` is the physical plan ## How was this patch tested? Added resolution and v2 SQL tests. Closes #24617 from rdblue/SPARK-27732-add-v2-create-table. Authored-by: Ryan Blue Signed-off-by: Wenchen Fan --- .../plans/logical/basicLogicalOperators.scala | 11 ++ .../datasources/DataSourceResolution.scala | 86 ++++++++++----- .../datasources/v2/CreateTableExec.scala | 56 ++++++++++ .../datasources/v2/DataSourceV2Strategy.scala | 5 +- .../command/PlanResolutionSuite.scala | 81 +++++++++++++- .../sql/sources/v2/DataSourceV2SQLSuite.scala | 104 ++++++++++++++++++ 6 files changed, 315 insertions(+), 28 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 9636f2fec38c0..7c2b047cdd3cc 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -392,6 +392,17 @@ trait V2WriteCommand extends Command { } } +/** + * Create a new table with a v2 catalog. + */ +case class CreateV2Table( + catalog: TableCatalog, + tableName: Identifier, + tableSchema: StructType, + partitioning: Seq[Transform], + properties: Map[String, String], + ignoreIfExists: Boolean) extends Command + /** * Create a new table from a select query with a v2 catalog. 
*/ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index 09506f05ccfa4..c8f90b84679f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils} -import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, LogicalPlan} import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.internal.SQLConf @@ -54,6 +54,15 @@ case class DataSourceResolution( CreateTable(tableDesc, mode, None) + case create: CreateTableStatement => + // the provider was not a v1 source, convert to a v2 plan + val CatalogObjectIdentifier(maybeCatalog, identifier) = create.tableName + val catalog = maybeCatalog.orElse(defaultCatalog) + .getOrElse(throw new AnalysisException( + s"No catalog specified for table ${identifier.quoted} and no default catalog is set")) + .asTableCatalog + convertCreateTable(catalog, identifier, create) + case CreateTableAsSelectStatement( AsTableIdentifier(table), query, partitionCols, bucketSpec, properties, V1WriteProvider(provider), options, location, comment, ifNotExists) => @@ -135,49 +144,76 @@ case class DataSourceResolution( catalog: TableCatalog, identifier: Identifier, ctas: CreateTableAsSelectStatement): CreateTableAsSelect = { - if (ctas.options.contains("path") && ctas.location.isDefined) { + // convert the bucket spec and add it as a transform + val partitioning = ctas.partitioning ++ ctas.bucketSpec.map(_.asTransform) + val properties = convertTableProperties( + ctas.properties, ctas.options, ctas.location, ctas.comment, ctas.provider) + + CreateTableAsSelect( + catalog, + identifier, + partitioning, + ctas.asSelect, + properties, + writeOptions = ctas.options.filterKeys(_ != "path"), + ignoreIfExists = ctas.ifNotExists) + } + + private def convertCreateTable( + catalog: TableCatalog, + identifier: Identifier, + create: CreateTableStatement): CreateV2Table = { + // convert the bucket spec and add it as a transform + val partitioning = create.partitioning ++ create.bucketSpec.map(_.asTransform) + val properties = convertTableProperties( + create.properties, create.options, create.location, create.comment, create.provider) + + CreateV2Table( + catalog, + identifier, + create.tableSchema, + partitioning, + properties, + ignoreIfExists = create.ifNotExists) + } + + private def convertTableProperties( + properties: Map[String, String], + options: Map[String, String], + location: Option[String], + comment: Option[String], + provider: String): Map[String, String] = { + if (options.contains("path") && location.isDefined) { throw new AnalysisException( "LOCATION and 'path' in OPTIONS are both used to indicate the custom table path, " + "you can only specify one of them.") } - if ((ctas.options.contains("provider") || ctas.properties.contains("provider")) - && 
ctas.comment.isDefined) { + if ((options.contains("comment") || properties.contains("comment")) + && comment.isDefined) { throw new AnalysisException( "COMMENT and option/property 'comment' are both used to set the table comment, you can " + "only specify one of them.") } - if (ctas.options.contains("provider") || ctas.properties.contains("provider")) { + if (options.contains("provider") || properties.contains("provider")) { throw new AnalysisException( "USING and option/property 'provider' are both used to set the provider implementation, " + "you can only specify one of them.") } - val options = ctas.options.filterKeys(_ != "path") - - // convert the bucket spec and add it as a transform - val partitioning = ctas.partitioning ++ ctas.bucketSpec.map(_.asTransform) + val filteredOptions = options.filterKeys(_ != "path") // create table properties from TBLPROPERTIES and OPTIONS clauses - val properties = new mutable.HashMap[String, String]() - properties ++= ctas.properties - properties ++= options + val tableProperties = new mutable.HashMap[String, String]() + tableProperties ++= properties + tableProperties ++= filteredOptions // convert USING, LOCATION, and COMMENT clauses to table properties - properties += ("provider" -> ctas.provider) - ctas.comment.map(text => properties += ("comment" -> text)) - ctas.location - .orElse(ctas.options.get("path")) - .map(loc => properties += ("location" -> loc)) + tableProperties += ("provider" -> provider) + comment.map(text => tableProperties += ("comment" -> text)) + location.orElse(options.get("path")).map(loc => tableProperties += ("location" -> loc)) - CreateTableAsSelect( - catalog, - identifier, - partitioning, - ctas.asSelect, - properties.toMap, - writeOptions = options, - ignoreIfExists = ctas.ifNotExists) + tableProperties.toMap } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala new file mode 100644 index 0000000000000..f35758bf08c67 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableExec.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import scala.collection.JavaConverters._ + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog} +import org.apache.spark.sql.catalog.v2.expressions.Transform +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.LeafExecNode +import org.apache.spark.sql.types.StructType + +case class CreateTableExec( + catalog: TableCatalog, + identifier: Identifier, + tableSchema: StructType, + partitioning: Seq[Transform], + tableProperties: Map[String, String], + ignoreIfExists: Boolean) extends LeafExecNode { + import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ + + override protected def doExecute(): RDD[InternalRow] = { + if (!catalog.tableExists(identifier)) { + try { + catalog.createTable(identifier, tableSchema, partitioning.toArray, tableProperties.asJava) + } catch { + case _: TableAlreadyExistsException if ignoreIfExists => + logWarning(s"Table ${identifier.quoted} was created concurrently. Ignoring.") + } + } else if (!ignoreIfExists) { + throw new TableAlreadyExistsException(identifier) + } + + sqlContext.sparkContext.parallelize(Seq.empty, 1) + } + + override def output: Seq[Attribute] = Seq.empty +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 165553c4da5bb..c5cd328a6fee4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import org.apache.spark.sql.{AnalysisException, Strategy} import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, CreateV2Table, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} @@ -148,6 +148,9 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { case WriteToDataSourceV2(writer, query) => WriteToDataSourceV2Exec(writer, planLater(query)) :: Nil + case CreateV2Table(catalog, ident, schema, parts, props, ifNotExists) => + CreateTableExec(catalog, ident, schema, parts, props, ifNotExists) :: Nil + case CreateTableAsSelect(catalog, ident, parts, query, props, options, ifNotExists) => val writeOptions = new CaseInsensitiveStringMap(options.asJava) CreateTableAsSelectExec( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index c525b4cbcba57..ebd21d8a1d53c 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -25,10 +25,10 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, LogicalPlan} import org.apache.spark.sql.execution.datasources.{CreateTable, DataSourceResolution} import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 -import org.apache.spark.sql.types.{IntegerType, StringType, StructType} +import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructType} import org.apache.spark.sql.util.CaseInsensitiveStringMap class PlanResolutionSuite extends AnalysisTest { @@ -295,6 +295,83 @@ class PlanResolutionSuite extends AnalysisTest { } } + test("Test v2 CreateTable with known catalog in identifier") { + val sql = + s""" + |CREATE TABLE IF NOT EXISTS testcat.mydb.table_name ( + | id bigint, + | description string, + | point struct) + |USING parquet + |COMMENT 'table comment' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |OPTIONS (path 's3://bucket/path/to/data', other 20) + """.stripMargin + + val expectedProperties = Map( + "p1" -> "v1", + "p2" -> "v2", + "other" -> "20", + "provider" -> "parquet", + "location" -> "s3://bucket/path/to/data", + "comment" -> "table comment") + + parseAndResolve(sql) match { + case create: CreateV2Table => + assert(create.catalog.name == "testcat") + assert(create.tableName == Identifier.of(Array("mydb"), "table_name")) + assert(create.tableSchema == new StructType() + .add("id", LongType) + .add("description", StringType) + .add("point", new StructType().add("x", DoubleType).add("y", DoubleType))) + assert(create.partitioning.isEmpty) + assert(create.properties == expectedProperties) + assert(create.ignoreIfExists) + + case other => + fail(s"Expected to parse ${classOf[CreateV2Table].getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + + test("Test v2 CreateTable with data source v2 provider") { + val sql = + s""" + |CREATE TABLE IF NOT EXISTS mydb.page_view ( + | id bigint, + | description string, + | point struct) + |USING $orc2 + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + """.stripMargin + + val expectedProperties = Map( + "p1" -> "v1", + "p2" -> "v2", + "provider" -> orc2, + "location" -> "/user/external/page_view", + "comment" -> "This is the staging page view table") + + parseAndResolve(sql) match { + case create: CreateV2Table => + assert(create.catalog.name == "testcat") + assert(create.tableName == Identifier.of(Array("mydb"), "page_view")) + assert(create.tableSchema == new StructType() + .add("id", LongType) + .add("description", StringType) + .add("point", new StructType().add("x", DoubleType).add("y", DoubleType))) + assert(create.partitioning.isEmpty) + assert(create.properties == expectedProperties) + assert(create.ignoreIfExists) + + case other => + fail(s"Expected to parse ${classOf[CreateV2Table].getName} from query," + + s"got ${other.getClass.getName}: $sql") + } + } + test("Test v2 CTAS with known 
catalog in identifier") { val sql = s""" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala index a9bc0369ad20f..606d2ad790b6b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala @@ -49,6 +49,110 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn spark.sql("DROP TABLE source") } + test("CreateTable: use v2 plan because catalog is set") { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> "foo").asJava) + assert(table.schema == new StructType().add("id", LongType).add("data", StringType)) + + val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty) + } + + test("CreateTable: use v2 plan because provider is v2") { + spark.sql(s"CREATE TABLE table_name (id bigint, data string) USING $orc2") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> orc2).asJava) + assert(table.schema == new StructType().add("id", LongType).add("data", StringType)) + + val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty) + } + + test("CreateTable: fail if table exists") { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string) USING foo") + + val testCatalog = spark.catalog("testcat").asTableCatalog + + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + assert(table.name == "testcat.table_name") + assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> "foo").asJava) + assert(table.schema == new StructType().add("id", LongType).add("data", StringType)) + + // run a second create query that should fail + val exc = intercept[TableAlreadyExistsException] { + spark.sql("CREATE TABLE testcat.table_name (id bigint, data string, id2 bigint) USING bar") + } + + assert(exc.getMessage.contains("table_name")) + + // table should not have changed + val table2 = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + assert(table2.name == "testcat.table_name") + assert(table2.partitioning.isEmpty) + assert(table2.properties == Map("provider" -> "foo").asJava) + assert(table2.schema == new StructType().add("id", LongType).add("data", StringType)) + + // check that the table is still empty + val rdd = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), Seq.empty) + } + + test("CreateTable: if not exists") { + spark.sql( + "CREATE TABLE IF NOT EXISTS testcat.table_name (id bigint, data string) USING foo") + + val testCatalog = spark.catalog("testcat").asTableCatalog + val table = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + + assert(table.name == "testcat.table_name") + 
assert(table.partitioning.isEmpty) + assert(table.properties == Map("provider" -> "foo").asJava) + assert(table.schema == new StructType().add("id", LongType).add("data", StringType)) + + spark.sql("CREATE TABLE IF NOT EXISTS testcat.table_name (id bigint, data string) USING bar") + + // table should not have changed + val table2 = testCatalog.loadTable(Identifier.of(Array(), "table_name")) + assert(table2.name == "testcat.table_name") + assert(table2.partitioning.isEmpty) + assert(table2.properties == Map("provider" -> "foo").asJava) + assert(table2.schema == new StructType().add("id", LongType).add("data", StringType)) + + // check that the table is still empty + val rdd2 = spark.sparkContext.parallelize(table.asInstanceOf[InMemoryTable].rows) + checkAnswer(spark.internalCreateDataFrame(rdd2, table.schema), Seq.empty) + } + + test("CreateTable: fail analysis when default catalog is needed but missing") { + val originalDefaultCatalog = conf.getConfString("spark.sql.default.catalog") + try { + conf.unsetConf("spark.sql.default.catalog") + + val exc = intercept[AnalysisException] { + spark.sql(s"CREATE TABLE table_name USING $orc2 AS SELECT id, data FROM source") + } + + assert(exc.getMessage.contains("No catalog specified for table")) + assert(exc.getMessage.contains("table_name")) + assert(exc.getMessage.contains("no default catalog is set")) + + } finally { + conf.setConfString("spark.sql.default.catalog", originalDefaultCatalog) + } + } + test("CreateTableAsSelect: use v2 plan because catalog is set") { spark.sql("CREATE TABLE testcat.table_name USING foo AS SELECT id, data FROM source") From 6244b770ae5b32bc690f8c01058b9bc89b1d83ab Mon Sep 17 00:00:00 2001 From: John Zhuge Date: Thu, 30 May 2019 09:22:42 +0800 Subject: [PATCH 45/70] [SPARK-26946][SQL][FOLLOWUP] Require lookup function ## What changes were proposed in this pull request? Require the lookup function with interface LookupCatalog. Rationale is in the review comments below. Make `Analyzer` abstract. BaseSessionStateBuilder and HiveSessionStateBuilder implements lookupCatalog with a call to SparkSession.catalog(). Existing test cases and those that don't need catalog lookup will use a newly added `TestAnalyzer` with a default lookup function that throws` CatalogNotFoundException("No catalog lookup function")`. Rewrote the unit test for LookupCatalog to demonstrate the interface can be used anywhere, not just Analyzer. Removed Analyzer parameter `lookupCatalog` because we can override in the following manner: ``` new Analyzer() { override def lookupCatalog(name: String): CatalogPlugin = ??? } ``` ## How was this patch tested? Existing unit tests. Closes #24689 from jzhuge/SPARK-26946-follow. 
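For reference, a minimal sketch of what the revised contract asks of implementors (illustrative only; `MyResolver` and its catalog map are hypothetical and not part of this patch): a component mixing in `LookupCatalog` now supplies the lookup by overriding the abstract method instead of providing an optional function value.

```scala
import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog}

// Hypothetical implementor of the revised trait: lookupCatalog is a required
// override rather than an Option[(String) => CatalogPlugin] member.
class MyResolver(catalogs: Map[String, CatalogPlugin]) extends LookupCatalog {
  override protected def lookupCatalog(name: String): CatalogPlugin =
    catalogs.getOrElse(name, throw new CatalogNotFoundException(s"Catalog '$name' is not defined"))
}
```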
Authored-by: John Zhuge Signed-off-by: Wenchen Fan --- .../spark/sql/catalog/v2/LookupCatalog.scala | 28 +++--- .../sql/catalyst/analysis/Analyzer.scala | 11 +-- .../catalog/v2/LookupCatalogSuite.scala | 88 +++++++++++++++++ .../v2/ResolveMultipartIdentifierSuite.scala | 99 ------------------- .../datasources/DataSourceResolution.scala | 2 +- 5 files changed, 105 insertions(+), 123 deletions(-) create mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala delete mode 100644 sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/ResolveMultipartIdentifierSuite.scala diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala index 932d32022702b..5464a7496d23d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalog/v2/LookupCatalog.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier @Experimental trait LookupCatalog { - def lookupCatalog: Option[(String) => CatalogPlugin] = None + protected def lookupCatalog(name: String): CatalogPlugin type CatalogObjectIdentifier = (Option[CatalogPlugin], Identifier) @@ -34,27 +34,23 @@ trait LookupCatalog { * Extract catalog plugin and identifier from a multi-part identifier. */ object CatalogObjectIdentifier { - def unapply(parts: Seq[String]): Option[CatalogObjectIdentifier] = lookupCatalog.map { lookup => - parts match { - case Seq(name) => - (None, Identifier.of(Array.empty, name)) - case Seq(catalogName, tail @ _*) => - try { - val catalog = lookup(catalogName) - (Some(catalog), Identifier.of(tail.init.toArray, tail.last)) - } catch { - case _: CatalogNotFoundException => - (None, Identifier.of(parts.init.toArray, parts.last)) - } - } + def unapply(parts: Seq[String]): Some[CatalogObjectIdentifier] = parts match { + case Seq(name) => + Some((None, Identifier.of(Array.empty, name))) + case Seq(catalogName, tail @ _*) => + try { + Some((Some(lookupCatalog(catalogName)), Identifier.of(tail.init.toArray, tail.last))) + } catch { + case _: CatalogNotFoundException => + Some((None, Identifier.of(parts.init.toArray, parts.last))) + } } } /** * Extract legacy table identifier from a multi-part identifier. * - * For legacy support only. Please use - * [[org.apache.spark.sql.catalog.v2.LookupCatalog.CatalogObjectIdentifier]] in DSv2 code paths. + * For legacy support only. Please use [[CatalogObjectIdentifier]] instead on DSv2 code paths. 
*/ object AsTableIdentifier { def unapply(parts: Seq[String]): Option[TableIdentifier] = parts match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index b60dd272c7a08..e0c0ad6efb483 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -24,7 +24,7 @@ import scala.collection.mutable.ArrayBuffer import scala.util.Random import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalog.v2.{CatalogPlugin, LookupCatalog} +import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, LookupCatalog} import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.encoders.OuterScopes @@ -96,18 +96,15 @@ object AnalysisContext { class Analyzer( catalog: SessionCatalog, conf: SQLConf, - maxIterations: Int, - override val lookupCatalog: Option[(String) => CatalogPlugin] = None) + maxIterations: Int) extends RuleExecutor[LogicalPlan] with CheckAnalysis with LookupCatalog { def this(catalog: SessionCatalog, conf: SQLConf) = { this(catalog, conf, conf.optimizerMaxIterations) } - def this(lookupCatalog: Option[(String) => CatalogPlugin], catalog: SessionCatalog, - conf: SQLConf) = { - this(catalog, conf, conf.optimizerMaxIterations, lookupCatalog) - } + override protected def lookupCatalog(name: String): CatalogPlugin = + throw new CatalogNotFoundException("No catalog lookup function") def executeAndCheck(plan: LogicalPlan, tracker: QueryPlanningTracker): LogicalPlan = { AnalysisHelper.markInAnalyzer { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala new file mode 100644 index 0000000000000..783751ff79865 --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/LookupCatalogSuite.scala @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.apache.spark.sql.catalyst.catalog.v2 + +import org.scalatest.Inside +import org.scalatest.Matchers._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, Identifier, LookupCatalog} +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +private case class TestCatalogPlugin(override val name: String) extends CatalogPlugin { + + override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = Unit +} + +class LookupCatalogSuite extends SparkFunSuite with LookupCatalog with Inside { + import CatalystSqlParser._ + + private val catalogs = Seq("prod", "test").map(x => x -> new TestCatalogPlugin(x)).toMap + + override def lookupCatalog(name: String): CatalogPlugin = + catalogs.getOrElse(name, throw new CatalogNotFoundException(s"$name not found")) + + test("catalog object identifier") { + Seq( + ("tbl", None, Seq.empty, "tbl"), + ("db.tbl", None, Seq("db"), "tbl"), + ("prod.func", catalogs.get("prod"), Seq.empty, "func"), + ("ns1.ns2.tbl", None, Seq("ns1", "ns2"), "tbl"), + ("prod.db.tbl", catalogs.get("prod"), Seq("db"), "tbl"), + ("test.db.tbl", catalogs.get("test"), Seq("db"), "tbl"), + ("test.ns1.ns2.ns3.tbl", catalogs.get("test"), Seq("ns1", "ns2", "ns3"), "tbl"), + ("`db.tbl`", None, Seq.empty, "db.tbl"), + ("parquet.`file:/tmp/db.tbl`", None, Seq("parquet"), "file:/tmp/db.tbl"), + ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", None, + Seq("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json")).foreach { + case (sql, expectedCatalog, namespace, name) => + inside(parseMultipartIdentifier(sql)) { + case CatalogObjectIdentifier(catalog, ident) => + catalog shouldEqual expectedCatalog + ident shouldEqual Identifier.of(namespace.toArray, name) + } + } + } + + test("table identifier") { + Seq( + ("tbl", "tbl", None), + ("db.tbl", "tbl", Some("db")), + ("`db.tbl`", "db.tbl", None), + ("parquet.`file:/tmp/db.tbl`", "file:/tmp/db.tbl", Some("parquet")), + ("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", "s3://buck/tmp/abc.json", + Some("org.apache.spark.sql.json"))).foreach { + case (sql, table, db) => + inside (parseMultipartIdentifier(sql)) { + case AsTableIdentifier(ident) => + ident shouldEqual TableIdentifier(table, db) + } + } + Seq( + "prod.func", + "prod.db.tbl", + "ns1.ns2.tbl").foreach { sql => + parseMultipartIdentifier(sql) match { + case AsTableIdentifier(_) => + fail(s"$sql should not be resolved as TableIdentifier") + case _ => + } + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/ResolveMultipartIdentifierSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/ResolveMultipartIdentifierSuite.scala deleted file mode 100644 index 0f2d67eaa9b20..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/v2/ResolveMultipartIdentifierSuite.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.catalyst.catalog.v2 - -import org.scalatest.Matchers._ - -import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin} -import org.apache.spark.sql.catalyst.TableIdentifier -import org.apache.spark.sql.catalyst.analysis.{AnalysisTest, Analyzer} -import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.util.CaseInsensitiveStringMap - -private class TestCatalogPlugin(override val name: String) extends CatalogPlugin { - - override def initialize(name: String, options: CaseInsensitiveStringMap): Unit = Unit -} - -class ResolveMultipartIdentifierSuite extends AnalysisTest { - import CatalystSqlParser._ - - private val analyzer = makeAnalyzer(caseSensitive = false) - - private val catalogs = Seq("prod", "test").map(name => name -> new TestCatalogPlugin(name)).toMap - - private def lookupCatalog(catalog: String): CatalogPlugin = - catalogs.getOrElse(catalog, throw new CatalogNotFoundException("Not found")) - - private def makeAnalyzer(caseSensitive: Boolean) = { - val conf = new SQLConf().copy(SQLConf.CASE_SENSITIVE -> caseSensitive) - new Analyzer(Some(lookupCatalog _), null, conf) - } - - override protected def getAnalyzer(caseSensitive: Boolean) = analyzer - - private def checkResolution(sqlText: String, expectedCatalog: Option[CatalogPlugin], - expectedNamespace: Array[String], expectedName: String): Unit = { - - import analyzer.CatalogObjectIdentifier - val CatalogObjectIdentifier(catalog, ident) = parseMultipartIdentifier(sqlText) - catalog shouldEqual expectedCatalog - ident.namespace shouldEqual expectedNamespace - ident.name shouldEqual expectedName - } - - private def checkTableResolution(sqlText: String, - expectedIdent: Option[TableIdentifier]): Unit = { - - import analyzer.AsTableIdentifier - parseMultipartIdentifier(sqlText) match { - case AsTableIdentifier(ident) => - assert(Some(ident) === expectedIdent) - case _ => - assert(None === expectedIdent) - } - } - - test("resolve multipart identifier") { - checkResolution("tbl", None, Array.empty, "tbl") - checkResolution("db.tbl", None, Array("db"), "tbl") - checkResolution("prod.func", catalogs.get("prod"), Array.empty, "func") - checkResolution("ns1.ns2.tbl", None, Array("ns1", "ns2"), "tbl") - checkResolution("prod.db.tbl", catalogs.get("prod"), Array("db"), "tbl") - checkResolution("test.db.tbl", catalogs.get("test"), Array("db"), "tbl") - checkResolution("test.ns1.ns2.ns3.tbl", - catalogs.get("test"), Array("ns1", "ns2", "ns3"), "tbl") - checkResolution("`db.tbl`", None, Array.empty, "db.tbl") - checkResolution("parquet.`file:/tmp/db.tbl`", None, Array("parquet"), "file:/tmp/db.tbl") - checkResolution("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", None, - Array("org.apache.spark.sql.json"), "s3://buck/tmp/abc.json") - } - - test("resolve table identifier") { - checkTableResolution("tbl", Some(TableIdentifier("tbl"))) - checkTableResolution("db.tbl", Some(TableIdentifier("tbl", Some("db")))) - checkTableResolution("prod.func", None) - checkTableResolution("ns1.ns2.tbl", None) - 
checkTableResolution("prod.db.tbl", None) - checkTableResolution("`db.tbl`", Some(TableIdentifier("db.tbl"))) - checkTableResolution("parquet.`file:/tmp/db.tbl`", - Some(TableIdentifier("file:/tmp/db.tbl", Some("parquet")))) - checkTableResolution("`org.apache.spark.sql.json`.`s3://buck/tmp/abc.json`", - Some(TableIdentifier("s3://buck/tmp/abc.json", Some("org.apache.spark.sql.json")))) - } -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index c8f90b84679f4..7583accda07c1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -41,7 +41,7 @@ case class DataSourceResolution( import org.apache.spark.sql.catalog.v2.CatalogV2Implicits._ - override def lookupCatalog: Option[String => CatalogPlugin] = Some(findCatalog) + override protected def lookupCatalog(name: String): CatalogPlugin = findCatalog(name) override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case CreateTableStatement( From 0db2aa05bebe64c27534abebd1a037165c7d19c7 Mon Sep 17 00:00:00 2001 From: John Zhuge Date: Fri, 31 May 2019 00:56:07 +0800 Subject: [PATCH 46/70] [SPARK-27813][SQL] DataSourceV2: Add DropTable logical operation ## What changes were proposed in this pull request? Support DROP TABLE from V2 catalogs. Move DROP TABLE into catalyst. Move parsing tests for DROP TABLE/VIEW to PlanResolutionSuite to validate existing behavior. Add new tests fo catalyst parser suite. Separate DROP VIEW into different code path from DROP TABLE. Move DROP VIEW into catalyst as a new operator. Add a meaningful exception to indicate view is not currently supported in v2 catalog. ## How was this patch tested? New unit tests. Existing unit tests in catalyst and sql core. Closes #24686 from jzhuge/SPARK-27813-pr. 
Authored-by: John Zhuge Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/parser/SqlBase.g4 | 4 +- .../spark/sql/catalog/v2/IdentifierImpl.java | 17 +++++ .../sql/catalyst/parser/AstBuilder.scala | 20 +++++- .../plans/logical/basicLogicalOperators.scala | 8 +++ .../logical/sql/DropTableStatement.scala | 34 ++++++++++ .../plans/logical/sql/DropViewStatement.scala | 33 ++++++++++ .../sql/catalyst/parser/DDLParserSuite.scala | 34 +++++++++- .../spark/sql/execution/SparkSqlParser.scala | 11 ---- .../datasources/DataSourceResolution.scala | 19 +++++- .../datasources/v2/DataSourceV2Strategy.scala | 5 +- .../datasources/v2/DropTableExec.scala | 44 +++++++++++++ .../execution/command/DDLParserSuite.scala | 63 +----------------- .../command/PlanResolutionSuite.scala | 65 ++++++++++++++++++- .../sql/sources/v2/DataSourceV2SQLSuite.scala | 18 ++++- 14 files changed, 294 insertions(+), 81 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropTableStatement.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropViewStatement.scala create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index cce108b0e1f03..4133331c7fc40 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -136,8 +136,8 @@ statement DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* #dropTablePartitions | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation | ALTER TABLE tableIdentifier RECOVER PARTITIONS #recoverPartitions - | DROP TABLE (IF EXISTS)? tableIdentifier PURGE? #dropTable - | DROP VIEW (IF EXISTS)? tableIdentifier #dropTable + | DROP TABLE (IF EXISTS)? multipartIdentifier PURGE? #dropTable + | DROP VIEW (IF EXISTS)? multipartIdentifier #dropView | CREATE (OR REPLACE)? (GLOBAL? TEMPORARY)? VIEW (IF NOT EXISTS)? tableIdentifier identifierCommentList? (COMMENT STRING)? diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/IdentifierImpl.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/IdentifierImpl.java index cd131432008a6..34f3882c9c412 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/IdentifierImpl.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalog/v2/IdentifierImpl.java @@ -22,6 +22,8 @@ import java.util.Arrays; import java.util.Objects; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** * An {@link Identifier} implementation. 
@@ -49,6 +51,21 @@ public String name() { return name; } + private String escapeQuote(String part) { + if (part.contains("`")) { + return part.replace("`", "``"); + } else { + return part; + } + } + + @Override + public String toString() { + return Stream.concat(Stream.of(namespace), Stream.of(name)) + .map(part -> '`' + escapeQuote(part) + '`') + .collect(Collectors.joining(".")); + } + @Override public boolean equals(Object o) { if (this == o) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 270c99d6cca8c..81ec2a1d9c904 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement} +import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -2151,4 +2151,22 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging } } + /** + * Create a [[DropTableStatement]] command. + */ + override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) { + DropTableStatement( + visitMultipartIdentifier(ctx.multipartIdentifier()), + ctx.EXISTS != null, + ctx.PURGE != null) + } + + /** + * Create a [[DropViewStatement]] command. + */ + override def visitDropView(ctx: DropViewContext): AnyRef = withOrigin(ctx) { + DropViewStatement( + visitMultipartIdentifier(ctx.multipartIdentifier()), + ctx.EXISTS != null) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index 7c2b047cdd3cc..256d3261055e2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -484,6 +484,14 @@ object OverwritePartitionsDynamic { } } +/** + * Drop a table. + */ +case class DropTable( + catalog: TableCatalog, + ident: Identifier, + ifExists: Boolean) extends Command + /** * Insert some data into a table. Note that this plan is unresolved and has to be replaced by the diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropTableStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropTableStatement.scala new file mode 100644 index 0000000000000..d41e8a5010257 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropTableStatement.scala @@ -0,0 +1,34 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.sql + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * A DROP TABLE statement, as parsed from SQL. + */ +case class DropTableStatement( + tableName: Seq[String], + ifExists: Boolean, + purge: Boolean) extends ParsedStatement { + + override def output: Seq[Attribute] = Seq.empty + + override def children: Seq[LogicalPlan] = Seq.empty +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropViewStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropViewStatement.scala new file mode 100644 index 0000000000000..523158788e834 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/DropViewStatement.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.sql + +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan + +/** + * A DROP VIEW statement, as parsed from SQL. 
+ */ +case class DropViewStatement( + viewName: Seq[String], + ifExists: Boolean) extends ParsedStatement { + + override def output: Seq[Attribute] = Seq.empty + + override def children: Seq[LogicalPlan] = Seq.empty +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 08baebbf140e6..35cd813ae65c5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -20,7 +20,8 @@ package org.apache.spark.sql.catalyst.parser import org.apache.spark.sql.catalog.v2.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement} import org.apache.spark.sql.types.{IntegerType, StringType, StructType, TimestampType} import org.apache.spark.unsafe.types.UTF8String @@ -34,6 +35,10 @@ class DDLParserSuite extends AnalysisTest { } } + private def parseCompare(sql: String, expected: LogicalPlan): Unit = { + comparePlans(parsePlan(sql), expected, checkAnalysis = false) + } + test("create table using - schema") { val sql = "CREATE TABLE my_tab(a INT COMMENT 'test', b STRING) USING parquet" @@ -362,4 +367,31 @@ class DDLParserSuite extends AnalysisTest { } } } + + test("drop table") { + parseCompare("DROP TABLE testcat.ns1.ns2.tbl", + DropTableStatement(Seq("testcat", "ns1", "ns2", "tbl"), ifExists = false, purge = false)) + parseCompare(s"DROP TABLE db.tab", + DropTableStatement(Seq("db", "tab"), ifExists = false, purge = false)) + parseCompare(s"DROP TABLE IF EXISTS db.tab", + DropTableStatement(Seq("db", "tab"), ifExists = true, purge = false)) + parseCompare(s"DROP TABLE tab", + DropTableStatement(Seq("tab"), ifExists = false, purge = false)) + parseCompare(s"DROP TABLE IF EXISTS tab", + DropTableStatement(Seq("tab"), ifExists = true, purge = false)) + parseCompare(s"DROP TABLE tab PURGE", + DropTableStatement(Seq("tab"), ifExists = false, purge = true)) + parseCompare(s"DROP TABLE IF EXISTS tab PURGE", + DropTableStatement(Seq("tab"), ifExists = true, purge = true)) + } + + test("drop view") { + parseCompare(s"DROP VIEW testcat.db.view", + DropViewStatement(Seq("testcat", "db", "view"), ifExists = false)) + parseCompare(s"DROP VIEW db.view", DropViewStatement(Seq("db", "view"), ifExists = false)) + parseCompare(s"DROP VIEW IF EXISTS db.view", + DropViewStatement(Seq("db", "view"), ifExists = true)) + parseCompare(s"DROP VIEW view", DropViewStatement(Seq("view"), ifExists = false)) + parseCompare(s"DROP VIEW IF EXISTS view", DropViewStatement(Seq("view"), ifExists = true)) + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index dd273937f9788..ac61661e83e32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -646,17 +646,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ctx.TEMPORARY != null) } - /** - * Create a [[DropTableCommand]] command. - */ - override def visitDropTable(ctx: DropTableContext): LogicalPlan = withOrigin(ctx) { - DropTableCommand( - visitTableIdentifier(ctx.tableIdentifier), - ctx.EXISTS != null, - ctx.VIEW != null, - ctx.PURGE != null) - } - /** * Create a [[AlterTableRenameCommand]] command. * diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index 7583accda07c1..7d34b6568a4fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -27,9 +27,10 @@ import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils} -import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, LogicalPlan} -import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement} +import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement} import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.execution.command.DropTableCommand import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.TableProvider import org.apache.spark.sql.types.StructType @@ -81,6 +82,20 @@ case class DataSourceResolution( s"No catalog specified for table ${identifier.quoted} and no default catalog is set")) .asTableCatalog convertCTAS(catalog, identifier, create) + + case DropTableStatement(CatalogObjectIdentifier(Some(catalog), ident), ifExists, _) => + DropTable(catalog.asTableCatalog, ident, ifExists) + + case DropTableStatement(AsTableIdentifier(tableName), ifExists, purge) => + DropTableCommand(tableName, ifExists, isView = false, purge) + + case DropViewStatement(CatalogObjectIdentifier(Some(catalog), ident), _) => + throw new AnalysisException( + s"Can not specify catalog `${catalog.name}` for view $ident " + + s"because view support in catalog has not been implemented yet") + + case DropViewStatement(AsTableIdentifier(tableName), ifExists) => + DropTableCommand(tableName, ifExists, isView = true, purge = false) } object V1WriteProvider { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index c5cd328a6fee4..d78b95336a76e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -23,7 +23,7 @@ import scala.collection.mutable import org.apache.spark.sql.{AnalysisException, Strategy} import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, 
PredicateHelper} import org.apache.spark.sql.catalyst.planning.PhysicalOperation -import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, CreateV2Table, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition} +import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} import org.apache.spark.sql.execution.datasources.DataSourceStrategy import org.apache.spark.sql.execution.streaming.continuous.{ContinuousCoalesceExec, WriteToContinuousDataSource, WriteToContinuousDataSourceExec} @@ -187,6 +187,9 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { Nil } + case DropTable(catalog, ident, ifExists) => + DropTableExec(catalog, ident, ifExists) :: Nil + case _ => Nil } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala new file mode 100644 index 0000000000000..d325e0205f9d8 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DropTableExec.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalog.v2.{Identifier, TableCatalog} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.analysis.NoSuchTableException +import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.execution.LeafExecNode + +/** + * Physical plan node for dropping a table. 
+ */ +case class DropTableExec(catalog: TableCatalog, ident: Identifier, ifExists: Boolean) + extends LeafExecNode { + + override def doExecute(): RDD[InternalRow] = { + if (catalog.tableExists(ident)) { + catalog.dropTable(ident) + } else if (!ifExists) { + throw new NoSuchTableException(ident) + } + + sqlContext.sparkContext.parallelize(Seq.empty, 1) + } + + override def output: Seq[Attribute] = Seq.empty +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index d430eeb294e13..0dd11c1e518e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -32,13 +32,12 @@ import org.apache.spark.sql.catalyst.dsl.plans.DslLogicalPlan import org.apache.spark.sql.catalyst.expressions.JsonTuple import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Generate, InsertIntoDir, LogicalPlan} -import org.apache.spark.sql.catalyst.plans.logical.{Project, ScriptTransformation} +import org.apache.spark.sql.catalyst.plans.logical.{Generate, InsertIntoDir, LogicalPlan, Project, ScriptTransformation} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.execution.datasources.CreateTable import org.apache.spark.sql.internal.{HiveSerDe, SQLConf} import org.apache.spark.sql.test.SharedSQLContext -import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{IntegerType, StructField, StructType} class DDLParserSuite extends PlanTest with SharedSQLContext { @@ -887,64 +886,6 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { assert(e.contains("Found an empty partition key 'b'")) } - test("drop table") { - val tableName1 = "db.tab" - val tableName2 = "tab" - - val parsed = Seq( - s"DROP TABLE $tableName1", - s"DROP TABLE IF EXISTS $tableName1", - s"DROP TABLE $tableName2", - s"DROP TABLE IF EXISTS $tableName2", - s"DROP TABLE $tableName2 PURGE", - s"DROP TABLE IF EXISTS $tableName2 PURGE" - ).map(parser.parsePlan) - - val expected = Seq( - DropTableCommand(TableIdentifier("tab", Option("db")), ifExists = false, isView = false, - purge = false), - DropTableCommand(TableIdentifier("tab", Option("db")), ifExists = true, isView = false, - purge = false), - DropTableCommand(TableIdentifier("tab", None), ifExists = false, isView = false, - purge = false), - DropTableCommand(TableIdentifier("tab", None), ifExists = true, isView = false, - purge = false), - DropTableCommand(TableIdentifier("tab", None), ifExists = false, isView = false, - purge = true), - DropTableCommand(TableIdentifier("tab", None), ifExists = true, isView = false, - purge = true)) - - parsed.zip(expected).foreach { case (p, e) => comparePlans(p, e) } - } - - test("drop view") { - val viewName1 = "db.view" - val viewName2 = "view" - - val parsed1 = parser.parsePlan(s"DROP VIEW $viewName1") - val parsed2 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName1") - val parsed3 = parser.parsePlan(s"DROP VIEW $viewName2") - val parsed4 = parser.parsePlan(s"DROP VIEW IF EXISTS $viewName2") - - val expected1 = - DropTableCommand(TableIdentifier("view", Option("db")), ifExists = false, isView = true, - purge = false) - val expected2 = - DropTableCommand(TableIdentifier("view", Option("db")), ifExists = 
true, isView = true, - purge = false) - val expected3 = - DropTableCommand(TableIdentifier("view", None), ifExists = false, isView = true, - purge = false) - val expected4 = - DropTableCommand(TableIdentifier("view", None), ifExists = true, isView = true, - purge = false) - - comparePlans(parsed1, expected1) - comparePlans(parsed2, expected2) - comparePlans(parsed3, expected3) - comparePlans(parsed4, expected4) - } - test("show columns") { val sql1 = "SHOW COLUMNS FROM t1" val sql2 = "SHOW COLUMNS IN db1.t1" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index ebd21d8a1d53c..60801910c6dbc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.command import java.net.URI +import java.util.Locale import org.apache.spark.sql.{AnalysisException, SaveMode} import org.apache.spark.sql.catalog.v2.{CatalogNotFoundException, CatalogPlugin, Identifier, TableCatalog, TestTableCatalog} @@ -25,7 +26,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogStorageFormat, CatalogTable, CatalogTableType} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser -import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan} import org.apache.spark.sql.execution.datasources.{CreateTable, DataSourceResolution} import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType, StructType} @@ -55,6 +56,9 @@ class PlanResolutionSuite extends AnalysisTest { DataSourceResolution(newConf, lookupCatalog).apply(parsePlan(query)) } + private def parseResolveCompare(query: String, expected: LogicalPlan): Unit = + comparePlans(parseAndResolve(query), expected, checkAnalysis = true) + private def extractTableDesc(sql: String): (CatalogTable, Boolean) = { parseAndResolve(sql).collect { case CreateTable(tableDesc, mode, _) => (tableDesc, mode == SaveMode.Ignore) @@ -439,4 +443,63 @@ class PlanResolutionSuite extends AnalysisTest { s"got ${other.getClass.getName}: $sql") } } + + test("drop table") { + val tableName1 = "db.tab" + val tableIdent1 = TableIdentifier("tab", Option("db")) + val tableName2 = "tab" + val tableIdent2 = TableIdentifier("tab", None) + + parseResolveCompare(s"DROP TABLE $tableName1", + DropTableCommand(tableIdent1, ifExists = false, isView = false, purge = false)) + parseResolveCompare(s"DROP TABLE IF EXISTS $tableName1", + DropTableCommand(tableIdent1, ifExists = true, isView = false, purge = false)) + parseResolveCompare(s"DROP TABLE $tableName2", + DropTableCommand(tableIdent2, ifExists = false, isView = false, purge = false)) + parseResolveCompare(s"DROP TABLE IF EXISTS $tableName2", + DropTableCommand(tableIdent2, ifExists = true, isView = false, purge = false)) + parseResolveCompare(s"DROP TABLE $tableName2 PURGE", + DropTableCommand(tableIdent2, ifExists = false, isView = false, purge = true)) + parseResolveCompare(s"DROP TABLE IF EXISTS $tableName2 PURGE", + DropTableCommand(tableIdent2, 
ifExists = true, isView = false, purge = true)) + } + + test("drop table in v2 catalog") { + val tableName1 = "testcat.db.tab" + val tableIdent1 = Identifier.of(Array("db"), "tab") + val tableName2 = "testcat.tab" + val tableIdent2 = Identifier.of(Array.empty, "tab") + + parseResolveCompare(s"DROP TABLE $tableName1", + DropTable(testCat, tableIdent1, ifExists = false)) + parseResolveCompare(s"DROP TABLE IF EXISTS $tableName1", + DropTable(testCat, tableIdent1, ifExists = true)) + parseResolveCompare(s"DROP TABLE $tableName2", + DropTable(testCat, tableIdent2, ifExists = false)) + parseResolveCompare(s"DROP TABLE IF EXISTS $tableName2", + DropTable(testCat, tableIdent2, ifExists = true)) + } + + test("drop view") { + val viewName1 = "db.view" + val viewIdent1 = TableIdentifier("view", Option("db")) + val viewName2 = "view" + val viewIdent2 = TableIdentifier("view") + + parseResolveCompare(s"DROP VIEW $viewName1", + DropTableCommand(viewIdent1, ifExists = false, isView = true, purge = false)) + parseResolveCompare(s"DROP VIEW IF EXISTS $viewName1", + DropTableCommand(viewIdent1, ifExists = true, isView = true, purge = false)) + parseResolveCompare(s"DROP VIEW $viewName2", + DropTableCommand(viewIdent2, ifExists = false, isView = true, purge = false)) + parseResolveCompare(s"DROP VIEW IF EXISTS $viewName2", + DropTableCommand(viewIdent2, ifExists = true, isView = true, purge = false)) + } + + test("drop view in v2 catalog") { + intercept[AnalysisException] { + parseAndResolve("DROP VIEW testcat.db.view") + }.getMessage.toLowerCase(Locale.ROOT).contains( + "view support in catalog has not been implemented") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala index 606d2ad790b6b..eaef458d38386 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala @@ -23,7 +23,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalog.v2.Identifier -import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException +import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, TableAlreadyExistsException} import org.apache.spark.sql.execution.datasources.v2.orc.OrcDataSourceV2 import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types.{LongType, StringType, StructType} @@ -267,4 +267,20 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn conf.setConfString("spark.sql.default.catalog", originalDefaultCatalog) } } + + test("DropTable: basic") { + val tableName = "testcat.ns1.ns2.tbl" + val ident = Identifier.of(Array("ns1", "ns2"), "tbl") + sql(s"CREATE TABLE $tableName USING foo AS SELECT id, data FROM source") + assert(spark.catalog("testcat").asTableCatalog.tableExists(ident) === true) + sql(s"DROP TABLE $tableName") + assert(spark.catalog("testcat").asTableCatalog.tableExists(ident) === false) + } + + test("DropTable: if exists") { + intercept[NoSuchTableException] { + sql(s"DROP TABLE testcat.db.notbl") + } + sql(s"DROP TABLE IF EXISTS testcat.db.notbl") + } } From d9e0cca491783a7af86d3f008f5e595d3d9e6cd0 Mon Sep 17 00:00:00 2001 From: SongYadong Date: Fri, 8 Mar 2019 10:51:39 -0800 Subject: [PATCH 47/70] [SPARK-27103][SQL][MINOR] List SparkSql reserved keywords in alphabet order ## What changes were 
proposed in this pull request? This PR tries to correct spark-sql reserved keywords' position in list if they are not in alphabetical order. In test suite some repeated words are removed. Also some comments are added for remind. ## How was this patch tested? Existing unit tests. Closes #23985 from SongYadong/sql_reserved_alphabet. Authored-by: SongYadong Signed-off-by: Dongjoon Hyun --- .../spark/sql/catalyst/parser/SqlBase.g4 | 501 ++++++++++++++++-- .../parser/TableIdentifierParserSuite.scala | 288 ++++++++-- 2 files changed, 718 insertions(+), 71 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 4133331c7fc40..04fbdd2ddd15f 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -788,62 +788,485 @@ number // // Let's say you add a new token `NEWTOKEN` and this is not reserved regardless of a `spark.sql.parser.ansi.enabled` // value. In this case, you must add a token `NEWTOKEN` in both `ansiNonReserved` and `nonReserved`. +// +// It is recommended to list them in alphabetical order. // The list of the reserved keywords when `spark.sql.parser.ansi.enabled` is true. Currently, we only reserve // the ANSI keywords that almost all the ANSI SQL standards (SQL-92, SQL-99, SQL-2003, SQL-2008, SQL-2011, // and SQL-2016) and PostgreSQL reserve. ansiReserved - : ALL | AND | ANTI | ANY | AS | AUTHORIZATION | BOTH | CASE | CAST | CHECK | COLLATE | COLUMN | CONSTRAINT | CREATE - | CROSS | CURRENT_DATE | CURRENT_TIME | CURRENT_TIMESTAMP | CURRENT_USER | DISTINCT | ELSE | END | EXCEPT | FALSE - | FETCH | FOR | FOREIGN | FROM | FULL | GRANT | GROUP | HAVING | IN | INNER | INTERSECT | INTO | JOIN | IS - | LEADING | LEFT | NATURAL | NOT | NULL | ON | ONLY | OR | ORDER | OUTER | OVERLAPS | PRIMARY | REFERENCES | RIGHT - | SELECT | SEMI | SESSION_USER | SETMINUS | SOME | TABLE | THEN | TO | TRAILING | UNION | UNIQUE | USER | USING - | WHEN | WHERE | WITH + : ALL + | AND + | ANTI + | ANY + | AS + | AUTHORIZATION + | BOTH + | CASE + | CAST + | CHECK + | COLLATE + | COLUMN + | CONSTRAINT + | CREATE + | CROSS + | CURRENT_DATE + | CURRENT_TIME + | CURRENT_TIMESTAMP + | CURRENT_USER + | DISTINCT + | ELSE + | END + | EXCEPT + | FALSE + | FETCH + | FOR + | FOREIGN + | FROM + | FULL + | GRANT + | GROUP + | HAVING + | IN + | INNER + | INTERSECT + | INTO + | IS + | JOIN + | LEADING + | LEFT + | NATURAL + | NOT + | NULL + | ON + | ONLY + | OR + | ORDER + | OUTER + | OVERLAPS + | PRIMARY + | REFERENCES + | RIGHT + | SELECT + | SEMI + | SESSION_USER + | SETMINUS + | SOME + | TABLE + | THEN + | TO + | TRAILING + | UNION + | UNIQUE + | USER + | USING + | WHEN + | WHERE + | WITH ; // The list of the non-reserved keywords when `spark.sql.parser.ansi.enabled` is true. 
ansiNonReserved - : ADD | AFTER | ALTER | ANALYZE | ARCHIVE | ARRAY | ASC | AT | BETWEEN | BUCKET | BUCKETS | BY | CACHE | CASCADE - | CHANGE | CLEAR | CLUSTER | CLUSTERED | CODEGEN | COLLECTION | COLUMNS | COMMENT | COMMIT | COMPACT | COMPACTIONS - | COMPUTE | CONCATENATE | COST | CUBE | CURRENT | DATA | DATABASE | DATABASES | DBPROPERTIES | DEFINED | DELETE - | DELIMITED | DESC | DESCRIBE | DFS | DIRECTORIES | DIRECTORY | DISTRIBUTE | DIV | DROP | ESCAPED | EXCHANGE - | EXISTS | EXPLAIN | EXPORT | EXTENDED | EXTERNAL | EXTRACT | FIELDS | FILEFORMAT | FIRST | FOLLOWING | FORMAT - | FORMATTED | FUNCTION | FUNCTIONS | GLOBAL | GROUPING | IF | IGNORE | IMPORT | INDEX | INDEXES | INPATH - | INPUTFORMAT | INSERT | INTERVAL | ITEMS | KEYS | LAST | LATERAL | LAZY | LIKE | LIMIT | LINES | LIST | LOAD - | LOCAL | LOCATION | LOCK | LOCKS | LOGICAL | MACRO | MAP | MSCK | NO | NULLS | OF | OPTION | OPTIONS | OUT - | OUTPUTFORMAT | OVER | OVERWRITE | PARTITION | PARTITIONED | PARTITIONS | PERCENT | PERCENTLIT | PIVOT | PRECEDING - | PRINCIPALS | PURGE | QUERY | RANGE | RECORDREADER | RECORDWRITER | RECOVER | REDUCE | REFRESH | RENAME | REPAIR | REPLACE - | RESET | RESTRICT | REVOKE | RLIKE | ROLE | ROLES | ROLLBACK | ROLLUP | ROW | ROWS | SCHEMA | SEPARATED | SERDE - | SERDEPROPERTIES | SET | SETS | SHOW | SKEWED | SORT | SORTED | START | STATISTICS | STORED | STRATIFY | STRUCT - | TABLES | TABLESAMPLE | TBLPROPERTIES | TEMPORARY | TERMINATED | TOUCH | TRANSACTION | TRANSACTIONS | TRANSFORM - | TRUE | TRUNCATE | UNARCHIVE | UNBOUNDED | UNCACHE | UNLOCK | UNSET | USE | VALUES | VIEW | WINDOW + : ADD + | AFTER + | ALTER + | ANALYZE + | ARCHIVE + | ARRAY + | ASC + | AT + | BETWEEN + | BUCKET + | BUCKETS + | BY + | CACHE + | CASCADE + | CHANGE + | CLEAR + | CLUSTER + | CLUSTERED + | CODEGEN + | COLLECTION + | COLUMNS + | COMMENT + | COMMIT + | COMPACT + | COMPACTIONS + | COMPUTE + | CONCATENATE + | COST + | CUBE + | CURRENT + | DATA + | DATABASE + | DATABASES + | DBPROPERTIES + | DEFINED + | DELETE + | DELIMITED + | DESC + | DESCRIBE + | DFS + | DIRECTORIES + | DIRECTORY + | DISTRIBUTE + | DIV + | DROP + | ESCAPED + | EXCHANGE + | EXISTS + | EXPLAIN + | EXPORT + | EXTENDED + | EXTERNAL + | EXTRACT + | FIELDS + | FILEFORMAT + | FIRST + | FOLLOWING + | FORMAT + | FORMATTED + | FUNCTION + | FUNCTIONS + | GLOBAL + | GROUPING + | IF + | IGNORE + | IMPORT + | INDEX + | INDEXES + | INPATH + | INPUTFORMAT + | INSERT + | INTERVAL + | ITEMS + | KEYS + | LAST + | LATERAL + | LAZY + | LIKE + | LIMIT + | LINES + | LIST + | LOAD + | LOCAL + | LOCATION + | LOCK + | LOCKS + | LOGICAL + | MACRO + | MAP + | MSCK + | NO + | NULLS + | OF + | OPTION + | OPTIONS + | OUT + | OUTPUTFORMAT + | OVER + | OVERWRITE + | PARTITION + | PARTITIONED + | PARTITIONS + | PERCENT + | PERCENTLIT + | PIVOT + | PRECEDING + | PRINCIPALS + | PURGE + | QUERY + | RANGE + | RECORDREADER + | RECORDWRITER + | RECOVER + | REDUCE + | REFRESH + | RENAME + | REPAIR + | REPLACE + | RESET + | RESTRICT + | REVOKE + | RLIKE + | ROLE + | ROLES + | ROLLBACK + | ROLLUP + | ROW + | ROWS + | SCHEMA + | SEPARATED + | SERDE + | SERDEPROPERTIES + | SET + | SETS + | SHOW + | SKEWED + | SORT + | SORTED + | START + | STATISTICS + | STORED + | STRATIFY + | STRUCT + | TABLES + | TABLESAMPLE + | TBLPROPERTIES + | TEMPORARY + | TERMINATED + | TOUCH + | TRANSACTION + | TRANSACTIONS + | TRANSFORM + | TRUE + | TRUNCATE + | UNARCHIVE + | UNBOUNDED + | UNCACHE + | UNLOCK + | UNSET + | USE + | VALUES + | VIEW + | WINDOW ; defaultReserved - : ANTI | CROSS | EXCEPT | FULL | INNER | 
INTERSECT | JOIN | LEFT | NATURAL | ON | RIGHT | SEMI | SETMINUS | UNION + : ANTI + | CROSS + | EXCEPT + | FULL + | INNER + | INTERSECT + | JOIN + | LEFT + | NATURAL + | ON + | RIGHT + | SEMI + | SETMINUS + | UNION | USING ; nonReserved - : ADD | AFTER | ALL | ALTER | ANALYZE | AND | ANY | ARCHIVE | ARRAY | AS | ASC | AT | AUTHORIZATION | BETWEEN - | BOTH | BUCKET | BUCKETS | BY | CACHE | CASCADE | CASE | CAST | CHANGE | CHECK | CLEAR | CLUSTER | CLUSTERED - | CODEGEN | COLLATE | COLLECTION | COLUMN | COLUMNS | COMMENT | COMMIT | COMPACT | COMPACTIONS | COMPUTE - | CONCATENATE | CONSTRAINT | COST | CREATE | CUBE | CURRENT | CURRENT_DATE | CURRENT_TIME | CURRENT_TIMESTAMP - | CURRENT_USER | DATA | DATABASE | DATABASES | DBPROPERTIES | DEFINED | DELETE | DELIMITED | DESC | DESCRIBE | DFS - | DIRECTORIES | DIRECTORY | DISTINCT | DISTRIBUTE | DIV | DROP | ELSE | END | ESCAPED | EXCHANGE | EXISTS | EXPLAIN - | EXPORT | EXTENDED | EXTERNAL | EXTRACT | FALSE | FETCH | FIELDS | FILEFORMAT | FIRST | FOLLOWING | FOR | FOREIGN - | FORMAT | FORMATTED | FROM | FUNCTION | FUNCTIONS | GLOBAL | GRANT | GROUP | GROUPING | HAVING | IF | IGNORE - | IMPORT | IN | INDEX | INDEXES | INPATH | INPUTFORMAT | INSERT | INTERVAL | INTO | IS | ITEMS | KEYS | LAST - | LATERAL | LAZY | LEADING | LIKE | LIMIT | LINES | LIST | LOAD | LOCAL | LOCATION | LOCK | LOCKS | LOGICAL | MACRO - | MAP | MSCK | NO | NOT | NULL | NULLS | OF | ONLY | OPTION | OPTIONS | OR | ORDER | OUT | OUTER | OUTPUTFORMAT - | OVER | OVERLAPS | OVERWRITE | PARTITION | PARTITIONED | PARTITIONS | PERCENTLIT | PIVOT | POSITION | PRECEDING - | PRIMARY | PRINCIPALS | PURGE | QUERY | RANGE | RECORDREADER | RECORDWRITER | RECOVER | REDUCE | REFERENCES | REFRESH - | RENAME | REPAIR | REPLACE | RESET | RESTRICT | REVOKE | RLIKE | ROLE | ROLES | ROLLBACK | ROLLUP | ROW | ROWS - | SELECT | SEPARATED | SERDE | SERDEPROPERTIES | SESSION_USER | SET | SETS | SHOW | SKEWED | SOME | SORT | SORTED - | START | STATISTICS | STORED | STRATIFY | STRUCT | TABLE | TABLES | TABLESAMPLE | TBLPROPERTIES | TEMPORARY - | TERMINATED | THEN | TO | TOUCH | TRAILING | TRANSACTION | TRANSACTIONS | TRANSFORM | TRUE | TRUNCATE | UNARCHIVE - | UNBOUNDED | UNCACHE | UNLOCK | UNIQUE | UNSET | USE | USER | VALUES | VIEW | WHEN | WHERE | WINDOW | WITH + : ADD + | AFTER + | ALL + | ALTER + | ANALYZE + | AND + | ANY + | ARCHIVE + | ARRAY + | AS + | ASC + | AT + | AUTHORIZATION + | BETWEEN + | BOTH + | BUCKET + | BUCKETS + | BY + | CACHE + | CASCADE + | CASE + | CAST + | CHANGE + | CHECK + | CLEAR + | CLUSTER + | CLUSTERED + | CODEGEN + | COLLATE + | COLLECTION + | COLUMN + | COLUMNS + | COMMENT + | COMMIT + | COMPACT + | COMPACTIONS + | COMPUTE + | CONCATENATE + | CONSTRAINT + | COST + | CREATE + | CUBE + | CURRENT + | CURRENT_DATE + | CURRENT_TIME + | CURRENT_TIMESTAMP + | CURRENT_USER + | DATA + | DATABASE + | DATABASES + | DBPROPERTIES + | DEFINED + | DELETE + | DELIMITED + | DESC + | DESCRIBE + | DFS + | DIRECTORIES + | DIRECTORY + | DISTINCT + | DISTRIBUTE + | DIV + | DROP + | ELSE + | END + | ESCAPED + | EXCHANGE + | EXISTS + | EXPLAIN + | EXPORT + | EXTENDED + | EXTERNAL + | EXTRACT + | FALSE + | FETCH + | FIELDS + | FILEFORMAT + | FIRST + | FOLLOWING + | FOR + | FOREIGN + | FORMAT + | FORMATTED + | FROM + | FUNCTION + | FUNCTIONS + | GLOBAL + | GRANT + | GROUP + | GROUPING + | HAVING + | IF + | IGNORE + | IMPORT + | IN + | INDEX + | INDEXES + | INPATH + | INPUTFORMAT + | INSERT + | INTERVAL + | INTO + | IS + | ITEMS + | KEYS + | LAST + | LATERAL + | LAZY + | LEADING + | LIKE + | LIMIT 
+ | LINES + | LIST + | LOAD + | LOCAL + | LOCATION + | LOCK + | LOCKS + | LOGICAL + | MACRO + | MAP + | MSCK + | NO + | NOT + | NULL + | NULLS + | OF + | ONLY + | OPTION + | OPTIONS + | OR + | ORDER + | OUT + | OUTER + | OUTPUTFORMAT + | OVER + | OVERLAPS + | OVERWRITE + | PARTITION + | PARTITIONED + | PARTITIONS + | PERCENTLIT + | PIVOT + | POSITION + | PRECEDING + | PRIMARY + | PRINCIPALS + | PURGE + | QUERY + | RANGE + | RECORDREADER + | RECORDWRITER + | RECOVER + | REDUCE + | REFERENCES + | REFRESH + | RENAME + | REPAIR + | REPLACE + | RESET + | RESTRICT + | REVOKE + | RLIKE + | ROLE + | ROLES + | ROLLBACK + | ROLLUP + | ROW + | ROWS + | SELECT + | SEPARATED + | SERDE + | SERDEPROPERTIES + | SESSION_USER + | SET + | SETS + | SHOW + | SKEWED + | SOME + | SORT + | SORTED + | START + | STATISTICS + | STORED + | STRATIFY + | STRUCT + | TABLE + | TABLES + | TABLESAMPLE + | TBLPROPERTIES + | TEMPORARY + | TERMINATED + | THEN + | TO + | TOUCH + | TRAILING + | TRANSACTION + | TRANSACTIONS + | TRANSFORM + | TRUE + | TRUNCATE + | UNARCHIVE + | UNBOUNDED + | UNCACHE + | UNIQUE + | UNLOCK + | UNSET + | USE + | USER + | VALUES + | VIEW + | WHEN + | WHERE + | WINDOW + | WITH ; SELECT: 'SELECT'; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index 489b7f328f8fa..3d41c27f217d9 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -23,39 +23,263 @@ class TableIdentifierParserSuite extends SparkFunSuite { import CatalystSqlParser._ // Add "$elem$", "$value$" & "$key$" - val hiveNonReservedKeyword = Array("add", "admin", "after", "analyze", "archive", "asc", "before", - "bucket", "buckets", "cascade", "change", "cluster", "clustered", "clusterstatus", "collection", - "columns", "comment", "compact", "compactions", "compute", "concatenate", "continue", "cost", - "data", "day", "databases", "datetime", "dbproperties", "deferred", "defined", "delimited", - "dependency", "desc", "directories", "directory", "disable", "distribute", - "enable", "escaped", "exclusive", "explain", "export", "fields", "file", "fileformat", "first", - "format", "formatted", "functions", "hold_ddltime", "hour", "idxproperties", "ignore", "index", - "indexes", "inpath", "inputdriver", "inputformat", "items", "jar", "keys", "key_type", "last", - "limit", "offset", "lines", "load", "location", "lock", "locks", "logical", "long", "mapjoin", - "materialized", "metadata", "minus", "minute", "month", "msck", "noscan", "no_drop", "nulls", - "offline", "option", "outputdriver", "outputformat", "overwrite", "owner", "partitioned", - "partitions", "plus", "pretty", "principals", "protection", "purge", "read", "readonly", - "rebuild", "recordreader", "recordwriter", "reload", "rename", "repair", "replace", - "replication", "restrict", "rewrite", "role", "roles", "schemas", "second", - "serde", "serdeproperties", "server", "sets", "shared", "show", "show_database", "skewed", - "sort", "sorted", "ssl", "statistics", "stored", "streamtable", "string", "struct", "tables", - "tblproperties", "temporary", "terminated", "tinyint", "touch", "transactions", "unarchive", - "undo", "uniontype", "unlock", "unset", "unsigned", "uri", "use", "utc", "utctimestamp", - "view", "while", "year", "work", "transaction", "write", "isolation", 
"level", "snapshot", - "autocommit", "all", "any", "alter", "array", "as", "authorization", "between", "bigint", - "binary", "boolean", "both", "by", "create", "cube", "current_date", "current_timestamp", - "cursor", "date", "decimal", "delete", "describe", "double", "drop", "exists", "external", - "false", "fetch", "float", "for", "grant", "group", "grouping", "import", "in", - "insert", "int", "into", "is", "pivot", "lateral", "like", "local", "none", "null", - "of", "order", "out", "outer", "partition", "percent", "procedure", "query", "range", "reads", - "revoke", "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", - "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", - "bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float", - "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing", "extract") + // It is recommended to list them in alphabetical order. + val hiveNonReservedKeyword = Array( + "add", + "admin", + "after", + "all", + "alter", + "analyze", + "any", + "archive", + "array", + "as", + "asc", + "at", + "authorization", + "autocommit", + "before", + "between", + "bigint", + "binary", + "boolean", + "both", + "bucket", + "buckets", + "by", + "cascade", + "change", + "cluster", + "clustered", + "clusterstatus", + "collection", + "columns", + "comment", + "compact", + "compactions", + "compute", + "concatenate", + "continue", + "cost", + "create", + "cube", + "current_date", + "current_timestamp", + "cursor", + "data", + "databases", + "date", + "datetime", + "day", + "dbproperties", + "decimal", + "deferred", + "defined", + "delete", + "delimited", + "dependency", + "desc", + "describe", + "directories", + "directory", + "disable", + "distribute", + "double", + "drop", + "enable", + "escaped", + "exclusive", + "exists", + "explain", + "export", + "external", + "extract", + "false", + "fetch", + "fields", + "file", + "fileformat", + "first", + "float", + "for", + "format", + "formatted", + "functions", + "grant", + "group", + "grouping", + "hold_ddltime", + "hour", + "idxproperties", + "ignore", + "import", + "in", + "index", + "indexes", + "inpath", + "inputdriver", + "inputformat", + "insert", + "int", + "into", + "is", + "isolation", + "items", + "jar", + "key_type", + "keys", + "last", + "lateral", + "leading", + "level", + "like", + "limit", + "lines", + "load", + "local", + "location", + "lock", + "locks", + "logical", + "long", + "mapjoin", + "materialized", + "metadata", + "minus", + "minute", + "month", + "msck", + "no_drop", + "none", + "noscan", + "null", + "nulls", + "of", + "offline", + "offset", + "option", + "order", + "out", + "outer", + "outputdriver", + "outputformat", + "overwrite", + "owner", + "partition", + "partitioned", + "partitions", + "percent", + "pivot", + "plus", + "position", + "pretty", + "principals", + "procedure", + "protection", + "purge", + "query", + "range", + "read", + "readonly", + "reads", + "rebuild", + "recordreader", + "recordwriter", + "regexp", + "reload", + "rename", + "repair", + "replace", + "replication", + "restrict", + "revoke", + "rewrite", + "rlike", + "role", + "roles", + "rollup", + "row", + "rows", + "schemas", + "second", + "serde", + "serdeproperties", + "server", + "set", + "sets", + "shared", + "show", + "show_database", + "skewed", + "smallint", + "snapshot", + "sort", + "sorted", + "ssl", + "statistics", + "stored", + "streamtable", + "string", + "struct", + "table", + "tables", + "tblproperties", + 
"temporary", + "terminated", + "timestamp", + "tinyint", + "to", + "touch", + "trailing", + "transaction", + "transactions", + "trigger", + "true", + "truncate", + "unarchive", + "undo", + "uniontype", + "unlock", + "unset", + "unsigned", + "update", + "uri", + "use", + "user", + "utc", + "utctimestamp", + "values", + "view", + "while", + "with", + "work", + "write", + "year") - val hiveStrictNonReservedKeyword = Seq("anti", "full", "inner", "left", "semi", "right", - "natural", "union", "intersect", "except", "database", "on", "join", "cross", "select", "from", - "where", "having", "from", "to", "table", "with", "not") + val hiveStrictNonReservedKeyword = Seq( + "anti", + "cross", + "database", + "except", + "from", + "full", + "having", + "inner", + "intersect", + "join", + "left", + "natural", + "not", + "on", + "right", + "select", + "semi", + "table", + "to", + "union", + "where", + "with") test("table identifier") { // Regular names. From e1365ba26e4fe04c8bbc891ec628195fb87d1cde Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Wed, 5 Jun 2019 13:21:30 -0700 Subject: [PATCH 48/70] [SPARK-27857][SQL] Move ALTER TABLE parsing into Catalyst This moves parsing logic for `ALTER TABLE` into Catalyst and adds parsed logical plans for alter table changes that use multi-part identifiers. This PR is similar to SPARK-27108, PR #24029, that created parsed logical plans for create and CTAS. * Create parsed logical plans * Move parsing logic into Catalyst's AstBuilder * Convert to DataSource plans in DataSourceResolution * Parse `ALTER TABLE ... SET LOCATION ...` separately from the partition variant * Parse `ALTER TABLE ... ALTER COLUMN ... [TYPE dataType] [COMMENT comment]` [as discussed on the dev list](http://apache-spark-developers-list.1001551.n3.nabble.com/DISCUSS-Syntax-for-table-DDL-td25197.html#a25270) * Parse `ALTER TABLE ... RENAME COLUMN ... TO ...` * Parse `ALTER TABLE ... DROP COLUMNS ...` * Added new tests in Catalyst's `DDLParserSuite` * Moved converted plan tests from SQL `DDLParserSuite` to `PlanResolutionSuite` * Existing tests for regressions Closes #24723 from rdblue/SPARK-27857-add-alter-table-statements-in-catalyst. 
Authored-by: Ryan Blue Signed-off-by: gatorsmile --- .../spark/sql/catalyst/parser/SqlBase.g4 | 98 +++++++- .../sql/catalyst/parser/AstBuilder.scala | 162 ++++++++++++- .../logical/sql/AlterTableStatements.scala | 78 +++++++ .../logical/sql/AlterViewStatements.scala | 33 +++ .../logical/sql/CreateTableStatement.scala | 10 +- .../plans/logical/sql/ParsedStatement.scala | 5 + .../sql/catalyst/parser/DDLParserSuite.scala | 215 +++++++++++++++++- .../spark/sql/execution/SparkSqlParser.scala | 60 +---- .../datasources/DataSourceResolution.scala | 42 +++- .../execution/command/DDLParserSuite.scala | 63 +---- .../command/PlanResolutionSuite.scala | 76 +++++++ 11 files changed, 701 insertions(+), 141 deletions(-) create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/AlterTableStatements.scala create mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/AlterViewStatements.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 04fbdd2ddd15f..706e96318d5a7 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -110,14 +110,27 @@ statement LIKE source=tableIdentifier locationSpec? #createTableLike | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS (identifier | FOR COLUMNS identifierSeq | FOR ALL COLUMNS)? #analyze - | ALTER TABLE tableIdentifier - ADD COLUMNS '(' columns=colTypeList ')' #addTableColumns + | ALTER TABLE multipartIdentifier + ADD (COLUMN | COLUMNS) + columns=qualifiedColTypeWithPositionList #addTableColumns + | ALTER TABLE multipartIdentifier + ADD (COLUMN | COLUMNS) + '(' columns=qualifiedColTypeWithPositionList ')' #addTableColumns + | ALTER TABLE multipartIdentifier + RENAME COLUMN from=qualifiedName TO to=identifier #renameTableColumn + | ALTER TABLE multipartIdentifier + DROP (COLUMN | COLUMNS) '(' columns=qualifiedNameList ')' #dropTableColumns + | ALTER TABLE multipartIdentifier + DROP (COLUMN | COLUMNS) columns=qualifiedNameList #dropTableColumns | ALTER (TABLE | VIEW) from=tableIdentifier RENAME TO to=tableIdentifier #renameTable - | ALTER (TABLE | VIEW) tableIdentifier + | ALTER (TABLE | VIEW) multipartIdentifier SET TBLPROPERTIES tablePropertyList #setTableProperties - | ALTER (TABLE | VIEW) tableIdentifier + | ALTER (TABLE | VIEW) multipartIdentifier UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties + | ALTER TABLE multipartIdentifier + (ALTER | CHANGE) COLUMN? qualifiedName + (TYPE dataType)? (COMMENT comment=STRING)? colPosition? #alterTableColumn | ALTER TABLE tableIdentifier partitionSpec? CHANGE COLUMN? identifier colType colPosition? #changeColumn | ALTER TABLE tableIdentifier (partitionSpec)? @@ -134,7 +147,8 @@ statement DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions | ALTER VIEW tableIdentifier DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* #dropTablePartitions - | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation + | ALTER TABLE multipartIdentifier SET locationSpec #setTableLocation + | ALTER TABLE tableIdentifier partitionSpec SET locationSpec #setPartitionLocation | ALTER TABLE tableIdentifier RECOVER PARTITIONS #recoverPartitions | DROP TABLE (IF EXISTS)? multipartIdentifier PURGE? #dropTable | DROP VIEW (IF EXISTS)? 
multipartIdentifier #dropView @@ -690,7 +704,7 @@ intervalValue ; colPosition - : FIRST | AFTER identifier + : FIRST | AFTER qualifiedName ; dataType @@ -700,6 +714,14 @@ dataType | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType ; +qualifiedColTypeWithPositionList + : qualifiedColTypeWithPosition (',' qualifiedColTypeWithPosition)* + ; + +qualifiedColTypeWithPosition + : name=qualifiedName dataType (COMMENT comment=STRING)? colPosition? + ; + colTypeList : colType (',' colType)* ; @@ -752,6 +774,10 @@ frameBound | expression boundType=(PRECEDING | FOLLOWING) ; +qualifiedNameList + : qualifiedName (',' qualifiedName)* + ; + qualifiedName : identifier ('.' identifier)* ; @@ -1253,6 +1279,7 @@ nonReserved | TRANSFORM | TRUE | TRUNCATE + | TYPE | UNARCHIVE | UNBOUNDED | UNCACHE @@ -1379,6 +1406,7 @@ RESET: 'RESET'; DATA: 'DATA'; START: 'START'; TRANSACTION: 'TRANSACTION'; +<<<<<<< HEAD COMMIT: 'COMMIT'; ROLLBACK: 'ROLLBACK'; MACRO: 'MACRO'; @@ -1390,6 +1418,64 @@ TRAILING: 'TRAILING'; IF: 'IF'; POSITION: 'POSITION'; EXTRACT: 'EXTRACT'; +||||||| parent of 5d6758c0e7... [SPARK-27857][SQL] Move ALTER TABLE parsing into Catalyst +TRANSACTIONS: 'TRANSACTIONS'; +TRANSFORM: 'TRANSFORM'; +TRUE: 'TRUE'; +TRUNCATE: 'TRUNCATE'; +UNARCHIVE: 'UNARCHIVE'; +UNBOUNDED: 'UNBOUNDED'; +UNCACHE: 'UNCACHE'; +UNION: 'UNION'; +UNIQUE: 'UNIQUE'; +UNLOCK: 'UNLOCK'; +UNSET: 'UNSET'; +USE: 'USE'; +USER: 'USER'; +USING: 'USING'; +VALUES: 'VALUES'; +VIEW: 'VIEW'; +WEEK: 'WEEK'; +WEEKS: 'WEEKS'; +WHEN: 'WHEN'; +WHERE: 'WHERE'; +WINDOW: 'WINDOW'; +WITH: 'WITH'; +YEAR: 'YEAR'; +YEARS: 'YEARS'; +//============================ +// End of the keywords list +//============================ +======= +TRANSACTIONS: 'TRANSACTIONS'; +TRANSFORM: 'TRANSFORM'; +TRUE: 'TRUE'; +TRUNCATE: 'TRUNCATE'; +TYPE: 'TYPE'; +UNARCHIVE: 'UNARCHIVE'; +UNBOUNDED: 'UNBOUNDED'; +UNCACHE: 'UNCACHE'; +UNION: 'UNION'; +UNIQUE: 'UNIQUE'; +UNLOCK: 'UNLOCK'; +UNSET: 'UNSET'; +USE: 'USE'; +USER: 'USER'; +USING: 'USING'; +VALUES: 'VALUES'; +VIEW: 'VIEW'; +WEEK: 'WEEK'; +WEEKS: 'WEEKS'; +WHEN: 'WHEN'; +WHERE: 'WHERE'; +WINDOW: 'WINDOW'; +WITH: 'WITH'; +YEAR: 'YEAR'; +YEARS: 'YEARS'; +//============================ +// End of the keywords list +//============================ +>>>>>>> 5d6758c0e7... 
[SPARK-27857][SQL] Move ALTER TABLE parsing into Catalyst EQ : '=' | '=='; NSEQ: '<=>'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 81ec2a1d9c904..4c1914841b58e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement} +import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -1991,6 +1991,13 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging (multipartIdentifier, temporary, ifNotExists, ctx.EXTERNAL != null) } + /** + * Parse a qualified name to a multipart name. + */ + override def visitQualifiedName(ctx: QualifiedNameContext): Seq[String] = withOrigin(ctx) { + ctx.identifier.asScala.map(_.getText) + } + /** * Parse a list of transforms. */ @@ -2023,8 +2030,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging ctx.transforms.asScala.map { case identityCtx: IdentityTransformContext => - IdentityTransform(FieldReference( - identityCtx.qualifiedName.identifier.asScala.map(_.getText))) + IdentityTransform(FieldReference(typedVisit[Seq[String]](identityCtx.qualifiedName))) case applyCtx: ApplyTransformContext => val arguments = applyCtx.argument.asScala.map(visitTransformArgument) @@ -2071,7 +2077,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging override def visitTransformArgument(ctx: TransformArgumentContext): v2.expressions.Expression = { withOrigin(ctx) { val reference = Option(ctx.qualifiedName) - .map(nameCtx => FieldReference(nameCtx.identifier.asScala.map(_.getText))) + .map(typedVisit[Seq[String]]) + .map(FieldReference(_)) val literal = Option(ctx.constant) .map(typedVisit[Literal]) .map(lit => LiteralValue(lit.value, lit.dataType)) @@ -2169,4 +2176,151 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging visitMultipartIdentifier(ctx.multipartIdentifier()), ctx.EXISTS != null) } + + /** + * Parse new column info from ADD COLUMN into a QualifiedColType. + */ + override def visitQualifiedColTypeWithPosition( + ctx: QualifiedColTypeWithPositionContext): QualifiedColType = withOrigin(ctx) { + if (ctx.colPosition != null) { + operationNotAllowed("ALTER TABLE table ADD COLUMN ... 
FIRST | AFTER otherCol", ctx) + } + + QualifiedColType( + typedVisit[Seq[String]](ctx.name), + typedVisit[DataType](ctx.dataType), + Option(ctx.comment).map(string)) + } + + /** + * Parse a [[AlterTableAddColumnsStatement]] command. + * + * For example: + * {{{ + * ALTER TABLE table1 + * ADD COLUMNS (col_name data_type [COMMENT col_comment], ...); + * }}} + */ + override def visitAddTableColumns(ctx: AddTableColumnsContext): LogicalPlan = withOrigin(ctx) { + AlterTableAddColumnsStatement( + visitMultipartIdentifier(ctx.multipartIdentifier), + ctx.columns.qualifiedColTypeWithPosition.asScala.map(typedVisit[QualifiedColType]) + ) + } + + /** + * Parse a [[AlterTableRenameColumnStatement]] command. + * + * For example: + * {{{ + * ALTER TABLE table1 RENAME COLUMN a.b.c TO x + * }}} + */ + override def visitRenameTableColumn( + ctx: RenameTableColumnContext): LogicalPlan = withOrigin(ctx) { + AlterTableRenameColumnStatement( + visitMultipartIdentifier(ctx.multipartIdentifier), + ctx.from.identifier.asScala.map(_.getText), + ctx.to.getText) + } + + /** + * Parse a [[AlterTableAlterColumnStatement]] command. + * + * For example: + * {{{ + * ALTER TABLE table1 ALTER COLUMN a.b.c TYPE bigint + * ALTER TABLE table1 ALTER COLUMN a.b.c TYPE bigint COMMENT 'new comment' + * ALTER TABLE table1 ALTER COLUMN a.b.c COMMENT 'new comment' + * }}} + */ + override def visitAlterTableColumn( + ctx: AlterTableColumnContext): LogicalPlan = withOrigin(ctx) { + val verb = if (ctx.CHANGE != null) "CHANGE" else "ALTER" + if (ctx.colPosition != null) { + operationNotAllowed(s"ALTER TABLE table $verb COLUMN ... FIRST | AFTER otherCol", ctx) + } + + if (ctx.dataType == null && ctx.comment == null) { + operationNotAllowed(s"ALTER TABLE table $verb COLUMN requires a TYPE or a COMMENT", ctx) + } + + AlterTableAlterColumnStatement( + visitMultipartIdentifier(ctx.multipartIdentifier), + typedVisit[Seq[String]](ctx.qualifiedName), + Option(ctx.dataType).map(typedVisit[DataType]), + Option(ctx.comment).map(string)) + } + + /** + * Parse a [[AlterTableDropColumnsStatement]] command. + * + * For example: + * {{{ + * ALTER TABLE table1 DROP COLUMN a.b.c + * ALTER TABLE table1 DROP COLUMNS a.b.c, x, y + * }}} + */ + override def visitDropTableColumns( + ctx: DropTableColumnsContext): LogicalPlan = withOrigin(ctx) { + val columnsToDrop = ctx.columns.qualifiedName.asScala.map(typedVisit[Seq[String]]) + AlterTableDropColumnsStatement( + visitMultipartIdentifier(ctx.multipartIdentifier), + columnsToDrop) + } + + /** + * Parse [[AlterViewSetPropertiesStatement]] or [[AlterTableSetPropertiesStatement]] commands. + * + * For example: + * {{{ + * ALTER TABLE table SET TBLPROPERTIES ('comment' = new_comment); + * ALTER VIEW view SET TBLPROPERTIES ('comment' = new_comment); + * }}} + */ + override def visitSetTableProperties( + ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { + val identifier = visitMultipartIdentifier(ctx.multipartIdentifier) + val properties = visitPropertyKeyValues(ctx.tablePropertyList) + if (ctx.VIEW != null) { + AlterViewSetPropertiesStatement(identifier, properties) + } else { + AlterTableSetPropertiesStatement(identifier, properties) + } + } + + /** + * Parse [[AlterViewUnsetPropertiesStatement]] or [[AlterTableUnsetPropertiesStatement]] commands. 
+ * + * For example: + * {{{ + * ALTER TABLE table UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + * ALTER VIEW view UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + * }}} + */ + override def visitUnsetTableProperties( + ctx: UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { + val identifier = visitMultipartIdentifier(ctx.multipartIdentifier) + val properties = visitPropertyKeys(ctx.tablePropertyList) + val ifExists = ctx.EXISTS != null + if (ctx.VIEW != null) { + AlterViewUnsetPropertiesStatement(identifier, properties, ifExists) + } else { + AlterTableUnsetPropertiesStatement(identifier, properties, ifExists) + } + } + + /** + * Create an [[AlterTableSetLocationStatement]] command. + * + * For example: + * {{{ + * ALTER TABLE table SET LOCATION "loc"; + * }}} + */ + override def visitSetTableLocation(ctx: SetTableLocationContext): LogicalPlan = withOrigin(ctx) { + AlterTableSetLocationStatement( + visitMultipartIdentifier(ctx.multipartIdentifier), + visitLocationSpec(ctx.locationSpec)) + } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/AlterTableStatements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/AlterTableStatements.scala new file mode 100644 index 0000000000000..9d7dec9ae0ce0 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/AlterTableStatements.scala @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.sql + +import org.apache.spark.sql.types.DataType + +/** + * Column data as parsed by ALTER TABLE ... ADD COLUMNS. + */ +case class QualifiedColType(name: Seq[String], dataType: DataType, comment: Option[String]) + +/** + * ALTER TABLE ... ADD COLUMNS command, as parsed from SQL. + */ +case class AlterTableAddColumnsStatement( + tableName: Seq[String], + columnsToAdd: Seq[QualifiedColType]) extends ParsedStatement + +/** + * ALTER TABLE ... CHANGE COLUMN command, as parsed from SQL. + */ +case class AlterTableAlterColumnStatement( + tableName: Seq[String], + column: Seq[String], + dataType: Option[DataType], + comment: Option[String]) extends ParsedStatement + +/** + * ALTER TABLE ... RENAME COLUMN command, as parsed from SQL. + */ +case class AlterTableRenameColumnStatement( + tableName: Seq[String], + column: Seq[String], + newName: String) extends ParsedStatement + +/** + * ALTER TABLE ... DROP COLUMNS command, as parsed from SQL. + */ +case class AlterTableDropColumnsStatement( + tableName: Seq[String], + columnsToDrop: Seq[Seq[String]]) extends ParsedStatement + +/** + * ALTER TABLE ... SET TBLPROPERTIES command, as parsed from SQL. 
+ */ +case class AlterTableSetPropertiesStatement( + tableName: Seq[String], + properties: Map[String, String]) extends ParsedStatement + +/** + * ALTER TABLE ... UNSET TBLPROPERTIES command, as parsed from SQL. + */ +case class AlterTableUnsetPropertiesStatement( + tableName: Seq[String], + propertyKeys: Seq[String], + ifExists: Boolean) extends ParsedStatement + +/** + * ALTER TABLE ... SET LOCATION command, as parsed from SQL. + */ +case class AlterTableSetLocationStatement( + tableName: Seq[String], + location: String) extends ParsedStatement diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/AlterViewStatements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/AlterViewStatements.scala new file mode 100644 index 0000000000000..bba7f12c94e50 --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/AlterViewStatements.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.plans.logical.sql + +/** + * ALTER VIEW ... SET TBLPROPERTIES command, as parsed from SQL. + */ +case class AlterViewSetPropertiesStatement( + viewName: Seq[String], + properties: Map[String, String]) extends ParsedStatement + +/** + * ALTER VIEW ... UNSET TBLPROPERTIES command, as parsed from SQL. + */ +case class AlterViewUnsetPropertiesStatement( + viewName: Seq[String], + propertyKeys: Seq[String], + ifExists: Boolean) extends ParsedStatement diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala index 7a26e01cde830..190711303e32d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.plans.logical.sql import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.catalog.BucketSpec -import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.StructType @@ -38,12 +37,7 @@ case class CreateTableStatement( options: Map[String, String], location: Option[String], comment: Option[String], - ifNotExists: Boolean) extends ParsedStatement { - - override def output: Seq[Attribute] = Seq.empty - - override def children: Seq[LogicalPlan] = Seq.empty -} + ifNotExists: Boolean) extends ParsedStatement /** * A CREATE TABLE AS SELECT command, as parsed from SQL. 
@@ -60,7 +54,5 @@ case class CreateTableAsSelectStatement( comment: Option[String], ifNotExists: Boolean) extends ParsedStatement { - override def output: Seq[Attribute] = Seq.empty - override def children: Seq[LogicalPlan] = Seq(asSelect) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala index 510f2a1ba1e6d..2942c4b1fcca5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.catalyst.plans.logical.sql +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** @@ -40,5 +41,9 @@ private[sql] abstract class ParsedStatement extends LogicalPlan { case other => other } + override def output: Seq[Attribute] = Seq.empty + + override def children: Seq[LogicalPlan] = Seq.empty + final override lazy val resolved = false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 35cd813ae65c5..39fd8afe13ff4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -17,23 +17,43 @@ package org.apache.spark.sql.catalyst.parser +import java.util.Locale + import org.apache.spark.sql.catalog.v2.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement} -import org.apache.spark.sql.types.{IntegerType, StringType, StructType, TimestampType} +import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType} import org.apache.spark.unsafe.types.UTF8String class DDLParserSuite extends AnalysisTest { import CatalystSqlParser._ +<<<<<<< HEAD private def intercept(sqlCommand: String, messages: String*): Unit = { val e = intercept[ParseException](parsePlan(sqlCommand)) messages.foreach { message => assert(e.message.contains(message)) } } +||||||| parent of fad827a417... 
[SPARK-27857][SQL] Move ALTER TABLE parsing into Catalyst + private def intercept(sqlCommand: String, messages: String*): Unit = + interceptParseException(parsePlan)(sqlCommand, messages: _*) +======= + private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = { + val e = intercept[ParseException] { + parsePlan(sql) + } + assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) + containsThesePhrases.foreach { p => + assert(e.getMessage.toLowerCase(Locale.ROOT).contains(p.toLowerCase(Locale.ROOT))) + } + } + + private def intercept(sqlCommand: String, messages: String*): Unit = + interceptParseException(parsePlan)(sqlCommand, messages: _*) +>>>>>>> fad827a417... [SPARK-27857][SQL] Move ALTER TABLE parsing into Catalyst private def parseCompare(sql: String, expected: LogicalPlan): Unit = { comparePlans(parsePlan(sql), expected, checkAnalysis = false) @@ -394,4 +414,195 @@ class DDLParserSuite extends AnalysisTest { parseCompare(s"DROP VIEW view", DropViewStatement(Seq("view"), ifExists = false)) parseCompare(s"DROP VIEW IF EXISTS view", DropViewStatement(Seq("view"), ifExists = true)) } + + // ALTER VIEW view_name SET TBLPROPERTIES ('comment' = new_comment); + // ALTER VIEW view_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + test("alter view: alter view properties") { + val sql1_view = "ALTER VIEW table_name SET TBLPROPERTIES ('test' = 'test', " + + "'comment' = 'new_comment')" + val sql2_view = "ALTER VIEW table_name UNSET TBLPROPERTIES ('comment', 'test')" + val sql3_view = "ALTER VIEW table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')" + + comparePlans(parsePlan(sql1_view), + AlterViewSetPropertiesStatement( + Seq("table_name"), Map("test" -> "test", "comment" -> "new_comment"))) + comparePlans(parsePlan(sql2_view), + AlterViewUnsetPropertiesStatement( + Seq("table_name"), Seq("comment", "test"), ifExists = false)) + comparePlans(parsePlan(sql3_view), + AlterViewUnsetPropertiesStatement( + Seq("table_name"), Seq("comment", "test"), ifExists = true)) + } + + // ALTER TABLE table_name SET TBLPROPERTIES ('comment' = new_comment); + // ALTER TABLE table_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + test("alter table: alter table properties") { + val sql1_table = "ALTER TABLE table_name SET TBLPROPERTIES ('test' = 'test', " + + "'comment' = 'new_comment')" + val sql2_table = "ALTER TABLE table_name UNSET TBLPROPERTIES ('comment', 'test')" + val sql3_table = "ALTER TABLE table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')" + + comparePlans( + parsePlan(sql1_table), + AlterTableSetPropertiesStatement( + Seq("table_name"), Map("test" -> "test", "comment" -> "new_comment"))) + comparePlans( + parsePlan(sql2_table), + AlterTableUnsetPropertiesStatement( + Seq("table_name"), Seq("comment", "test"), ifExists = false)) + comparePlans( + parsePlan(sql3_table), + AlterTableUnsetPropertiesStatement( + Seq("table_name"), Seq("comment", "test"), ifExists = true)) + } + + test("alter table: add column") { + comparePlans( + parsePlan("ALTER TABLE table_name ADD COLUMN x int"), + AlterTableAddColumnsStatement(Seq("table_name"), Seq( + QualifiedColType(Seq("x"), IntegerType, None) + ))) + } + + test("alter table: add multiple columns") { + comparePlans( + parsePlan("ALTER TABLE table_name ADD COLUMNS x int, y string"), + AlterTableAddColumnsStatement(Seq("table_name"), Seq( + QualifiedColType(Seq("x"), IntegerType, None), + QualifiedColType(Seq("y"), StringType, None) + ))) + } + + test("alter table: add 
column with COLUMNS") { + comparePlans( + parsePlan("ALTER TABLE table_name ADD COLUMNS x int"), + AlterTableAddColumnsStatement(Seq("table_name"), Seq( + QualifiedColType(Seq("x"), IntegerType, None) + ))) + } + + test("alter table: add column with COLUMNS (...)") { + comparePlans( + parsePlan("ALTER TABLE table_name ADD COLUMNS (x int)"), + AlterTableAddColumnsStatement(Seq("table_name"), Seq( + QualifiedColType(Seq("x"), IntegerType, None) + ))) + } + + test("alter table: add column with COLUMNS (...) and COMMENT") { + comparePlans( + parsePlan("ALTER TABLE table_name ADD COLUMNS (x int COMMENT 'doc')"), + AlterTableAddColumnsStatement(Seq("table_name"), Seq( + QualifiedColType(Seq("x"), IntegerType, Some("doc")) + ))) + } + + test("alter table: add column with COMMENT") { + comparePlans( + parsePlan("ALTER TABLE table_name ADD COLUMN x int COMMENT 'doc'"), + AlterTableAddColumnsStatement(Seq("table_name"), Seq( + QualifiedColType(Seq("x"), IntegerType, Some("doc")) + ))) + } + + test("alter table: add column with nested column name") { + comparePlans( + parsePlan("ALTER TABLE table_name ADD COLUMN x.y.z int COMMENT 'doc'"), + AlterTableAddColumnsStatement(Seq("table_name"), Seq( + QualifiedColType(Seq("x", "y", "z"), IntegerType, Some("doc")) + ))) + } + + test("alter table: add multiple columns with nested column name") { + comparePlans( + parsePlan("ALTER TABLE table_name ADD COLUMN x.y.z int COMMENT 'doc', a.b string"), + AlterTableAddColumnsStatement(Seq("table_name"), Seq( + QualifiedColType(Seq("x", "y", "z"), IntegerType, Some("doc")), + QualifiedColType(Seq("a", "b"), StringType, None) + ))) + } + + test("alter table: add column at position (not supported)") { + assertUnsupported("ALTER TABLE table_name ADD COLUMNS name bigint COMMENT 'doc' FIRST, a.b int") + assertUnsupported("ALTER TABLE table_name ADD COLUMN name bigint COMMENT 'doc' FIRST") + assertUnsupported("ALTER TABLE table_name ADD COLUMN name string AFTER a.b") + } + + test("alter table: set location") { + val sql1 = "ALTER TABLE table_name SET LOCATION 'new location'" + val parsed1 = parsePlan(sql1) + val expected1 = AlterTableSetLocationStatement(Seq("table_name"), "new location") + comparePlans(parsed1, expected1) + } + + test("alter table: rename column") { + comparePlans( + parsePlan("ALTER TABLE table_name RENAME COLUMN a.b.c TO d"), + AlterTableRenameColumnStatement( + Seq("table_name"), + Seq("a", "b", "c"), + "d")) + } + + test("alter table: update column type using ALTER") { + comparePlans( + parsePlan("ALTER TABLE table_name ALTER COLUMN a.b.c TYPE bigint"), + AlterTableAlterColumnStatement( + Seq("table_name"), + Seq("a", "b", "c"), + Some(LongType), + None)) + } + + test("alter table: update column type") { + comparePlans( + parsePlan("ALTER TABLE table_name CHANGE COLUMN a.b.c TYPE bigint"), + AlterTableAlterColumnStatement( + Seq("table_name"), + Seq("a", "b", "c"), + Some(LongType), + None)) + } + + test("alter table: update column comment") { + comparePlans( + parsePlan("ALTER TABLE table_name CHANGE COLUMN a.b.c COMMENT 'new comment'"), + AlterTableAlterColumnStatement( + Seq("table_name"), + Seq("a", "b", "c"), + None, + Some("new comment"))) + } + + test("alter table: update column type and comment") { + comparePlans( + parsePlan("ALTER TABLE table_name CHANGE COLUMN a.b.c TYPE bigint COMMENT 'new comment'"), + AlterTableAlterColumnStatement( + Seq("table_name"), + Seq("a", "b", "c"), + Some(LongType), + Some("new comment"))) + } + + test("alter table: change column position (not supported)") { + 
assertUnsupported("ALTER TABLE table_name CHANGE COLUMN name COMMENT 'doc' FIRST") + assertUnsupported("ALTER TABLE table_name CHANGE COLUMN name TYPE INT AFTER other_col") + } + + test("alter table: drop column") { + comparePlans( + parsePlan("ALTER TABLE table_name DROP COLUMN a.b.c"), + AlterTableDropColumnsStatement(Seq("table_name"), Seq(Seq("a", "b", "c")))) + } + + test("alter table: drop multiple columns") { + val sql = "ALTER TABLE table_name DROP COLUMN x, y, a.b.c" + Seq(sql, sql.replace("COLUMN", "COLUMNS")).foreach { drop => + comparePlans( + parsePlan(drop), + AlterTableDropColumnsStatement( + Seq("table_name"), + Seq(Seq("x"), Seq("y"), Seq("a", "b", "c")))) + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index ac61661e83e32..f33abdda47522 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -662,57 +662,6 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ctx.VIEW != null) } - /** - * Create a [[AlterTableAddColumnsCommand]] command. - * - * For example: - * {{{ - * ALTER TABLE table1 - * ADD COLUMNS (col_name data_type [COMMENT col_comment], ...); - * }}} - */ - override def visitAddTableColumns(ctx: AddTableColumnsContext): LogicalPlan = withOrigin(ctx) { - AlterTableAddColumnsCommand( - visitTableIdentifier(ctx.tableIdentifier), - visitColTypeList(ctx.columns) - ) - } - - /** - * Create an [[AlterTableSetPropertiesCommand]] command. - * - * For example: - * {{{ - * ALTER TABLE table SET TBLPROPERTIES ('comment' = new_comment); - * ALTER VIEW view SET TBLPROPERTIES ('comment' = new_comment); - * }}} - */ - override def visitSetTableProperties( - ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { - AlterTableSetPropertiesCommand( - visitTableIdentifier(ctx.tableIdentifier), - visitPropertyKeyValues(ctx.tablePropertyList), - ctx.VIEW != null) - } - - /** - * Create an [[AlterTableUnsetPropertiesCommand]] command. - * - * For example: - * {{{ - * ALTER TABLE table UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); - * ALTER VIEW view UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); - * }}} - */ - override def visitUnsetTableProperties( - ctx: UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { - AlterTableUnsetPropertiesCommand( - visitTableIdentifier(ctx.tableIdentifier), - visitPropertyKeys(ctx.tablePropertyList), - ctx.EXISTS != null, - ctx.VIEW != null) - } - /** * Create an [[AlterTableSerDePropertiesCommand]] command. * @@ -821,17 +770,18 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } /** - * Create an [[AlterTableSetLocationCommand]] command + * Create an [[AlterTableSetLocationCommand]] command for a partition. 
* * For example: * {{{ - * ALTER TABLE table [PARTITION spec] SET LOCATION "loc"; + * ALTER TABLE table PARTITION spec SET LOCATION "loc"; * }}} */ - override def visitSetTableLocation(ctx: SetTableLocationContext): LogicalPlan = withOrigin(ctx) { + override def visitSetPartitionLocation( + ctx: SetPartitionLocationContext): LogicalPlan = withOrigin(ctx) { AlterTableSetLocationCommand( visitTableIdentifier(ctx.tableIdentifier), - Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), + Some(visitNonOptionalPartitionSpec(ctx.partitionSpec)), visitLocationSpec(ctx.locationSpec)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index 7d34b6568a4fc..c5f8cf24fe7d0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -28,12 +28,12 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils} import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan} -import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement} +import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.command.DropTableCommand +import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, DropTableCommand} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.TableProvider -import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.types.{HIVE_TYPE_STRING, HiveStringType, MetadataBuilder, StructField, StructType} case class DataSourceResolution( conf: SQLConf, @@ -96,6 +96,26 @@ case class DataSourceResolution( case DropViewStatement(AsTableIdentifier(tableName), ifExists) => DropTableCommand(tableName, ifExists, isView = true, purge = false) + + case AlterTableSetPropertiesStatement(AsTableIdentifier(table), properties) => + AlterTableSetPropertiesCommand(table, properties, isView = false) + + case AlterViewSetPropertiesStatement(AsTableIdentifier(table), properties) => + AlterTableSetPropertiesCommand(table, properties, isView = true) + + case AlterTableUnsetPropertiesStatement(AsTableIdentifier(table), propertyKeys, ifExists) => + AlterTableUnsetPropertiesCommand(table, propertyKeys, ifExists, isView = false) + + case AlterViewUnsetPropertiesStatement(AsTableIdentifier(table), propertyKeys, ifExists) => + AlterTableUnsetPropertiesCommand(table, propertyKeys, ifExists, isView = true) + + case AlterTableSetLocationStatement(AsTableIdentifier(table), newLocation) => + AlterTableSetLocationCommand(table, None, newLocation) + + case AlterTableAddColumnsStatement(AsTableIdentifier(table), 
newColumns) + if newColumns.forall(_.name.size == 1) => + // only top-level adds are supported using AlterTableAddColumnsCommand + AlterTableAddColumnsCommand(table, newColumns.map(convertToStructField)) } object V1WriteProvider { @@ -231,4 +251,20 @@ case class DataSourceResolution( tableProperties.toMap } + + private def convertToStructField(col: QualifiedColType): StructField = { + val builder = new MetadataBuilder + col.comment.foreach(builder.putString("comment", _)) + + val cleanedDataType = HiveStringType.replaceCharType(col.dataType) + if (col.dataType != cleanedDataType) { + builder.putString(HIVE_TYPE_STRING, col.dataType.catalogString) + } + + StructField( + col.name.head, + cleanedDataType, + nullable = true, + builder.build()) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 0dd11c1e518e0..8203c900329c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -511,45 +511,6 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { assert(plan.newName == TableIdentifier("tbl2", Some("db1"))) } - // ALTER TABLE table_name SET TBLPROPERTIES ('comment' = new_comment); - // ALTER TABLE table_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); - // ALTER VIEW view_name SET TBLPROPERTIES ('comment' = new_comment); - // ALTER VIEW view_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); - test("alter table/view: alter table/view properties") { - val sql1_table = "ALTER TABLE table_name SET TBLPROPERTIES ('test' = 'test', " + - "'comment' = 'new_comment')" - val sql2_table = "ALTER TABLE table_name UNSET TBLPROPERTIES ('comment', 'test')" - val sql3_table = "ALTER TABLE table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')" - val sql1_view = sql1_table.replace("TABLE", "VIEW") - val sql2_view = sql2_table.replace("TABLE", "VIEW") - val sql3_view = sql3_table.replace("TABLE", "VIEW") - - val parsed1_table = parser.parsePlan(sql1_table) - val parsed2_table = parser.parsePlan(sql2_table) - val parsed3_table = parser.parsePlan(sql3_table) - val parsed1_view = parser.parsePlan(sql1_view) - val parsed2_view = parser.parsePlan(sql2_view) - val parsed3_view = parser.parsePlan(sql3_view) - - val tableIdent = TableIdentifier("table_name", None) - val expected1_table = AlterTableSetPropertiesCommand( - tableIdent, Map("test" -> "test", "comment" -> "new_comment"), isView = false) - val expected2_table = AlterTableUnsetPropertiesCommand( - tableIdent, Seq("comment", "test"), ifExists = false, isView = false) - val expected3_table = AlterTableUnsetPropertiesCommand( - tableIdent, Seq("comment", "test"), ifExists = true, isView = false) - val expected1_view = expected1_table.copy(isView = true) - val expected2_view = expected2_table.copy(isView = true) - val expected3_view = expected3_table.copy(isView = true) - - comparePlans(parsed1_table, expected1_table) - comparePlans(parsed2_table, expected2_table) - comparePlans(parsed3_table, expected3_table) - comparePlans(parsed1_view, expected1_view) - comparePlans(parsed2_view, expected2_view) - comparePlans(parsed3_view, expected3_view) - } - test("alter table - property values must be set") { assertUnsupported( sql = "ALTER TABLE my_tab SET TBLPROPERTIES('key_without_value', 'key_with_value'='x')", @@ -747,22 +708,15 @@ class DDLParserSuite extends 
PlanTest with SharedSQLContext { "SET FILEFORMAT PARQUET") } - test("alter table: set location") { - val sql1 = "ALTER TABLE table_name SET LOCATION 'new location'" + test("alter table: set partition location") { val sql2 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') " + "SET LOCATION 'new location'" - val parsed1 = parser.parsePlan(sql1) val parsed2 = parser.parsePlan(sql2) val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableSetLocationCommand( - tableIdent, - None, - "new location") val expected2 = AlterTableSetLocationCommand( tableIdent, Some(Map("dt" -> "2008-08-08", "country" -> "us")), "new location") - comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) } @@ -946,21 +900,6 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { comparePlans(parsed, expected) } - test("support for other types in TBLPROPERTIES") { - val sql = - """ - |ALTER TABLE table_name - |SET TBLPROPERTIES ('a' = 1, 'b' = 0.1, 'c' = TRUE) - """.stripMargin - val parsed = parser.parsePlan(sql) - val expected = AlterTableSetPropertiesCommand( - TableIdentifier("table_name"), - Map("a" -> "1", "b" -> "0.1", "c" -> "true"), - isView = false) - - comparePlans(parsed, expected) - } - test("Test CTAS #1") { val s1 = """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 60801910c6dbc..a834932110896 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -502,4 +502,80 @@ class PlanResolutionSuite extends AnalysisTest { }.getMessage.toLowerCase(Locale.ROOT).contains( "view support in catalog has not been implemented") } + + // ALTER VIEW view_name SET TBLPROPERTIES ('comment' = new_comment); + // ALTER VIEW view_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + test("alter view: alter view properties") { + val sql1_view = "ALTER VIEW table_name SET TBLPROPERTIES ('test' = 'test', " + + "'comment' = 'new_comment')" + val sql2_view = "ALTER VIEW table_name UNSET TBLPROPERTIES ('comment', 'test')" + val sql3_view = "ALTER VIEW table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')" + + val parsed1_view = parseAndResolve(sql1_view) + val parsed2_view = parseAndResolve(sql2_view) + val parsed3_view = parseAndResolve(sql3_view) + + val tableIdent = TableIdentifier("table_name", None) + val expected1_view = AlterTableSetPropertiesCommand( + tableIdent, Map("test" -> "test", "comment" -> "new_comment"), isView = true) + val expected2_view = AlterTableUnsetPropertiesCommand( + tableIdent, Seq("comment", "test"), ifExists = false, isView = true) + val expected3_view = AlterTableUnsetPropertiesCommand( + tableIdent, Seq("comment", "test"), ifExists = true, isView = true) + + comparePlans(parsed1_view, expected1_view) + comparePlans(parsed2_view, expected2_view) + comparePlans(parsed3_view, expected3_view) + } + + // ALTER TABLE table_name SET TBLPROPERTIES ('comment' = new_comment); + // ALTER TABLE table_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + test("alter table: alter table properties") { + val sql1_table = "ALTER TABLE table_name SET TBLPROPERTIES ('test' = 'test', " + + "'comment' = 'new_comment')" + val sql2_table = "ALTER TABLE table_name UNSET TBLPROPERTIES ('comment', 'test')" + val sql3_table = "ALTER TABLE table_name UNSET 
TBLPROPERTIES IF EXISTS ('comment', 'test')" + + val parsed1_table = parseAndResolve(sql1_table) + val parsed2_table = parseAndResolve(sql2_table) + val parsed3_table = parseAndResolve(sql3_table) + + val tableIdent = TableIdentifier("table_name", None) + val expected1_table = AlterTableSetPropertiesCommand( + tableIdent, Map("test" -> "test", "comment" -> "new_comment"), isView = false) + val expected2_table = AlterTableUnsetPropertiesCommand( + tableIdent, Seq("comment", "test"), ifExists = false, isView = false) + val expected3_table = AlterTableUnsetPropertiesCommand( + tableIdent, Seq("comment", "test"), ifExists = true, isView = false) + + comparePlans(parsed1_table, expected1_table) + comparePlans(parsed2_table, expected2_table) + comparePlans(parsed3_table, expected3_table) + } + + test("support for other types in TBLPROPERTIES") { + val sql = + """ + |ALTER TABLE table_name + |SET TBLPROPERTIES ('a' = 1, 'b' = 0.1, 'c' = TRUE) + """.stripMargin + val parsed = parseAndResolve(sql) + val expected = AlterTableSetPropertiesCommand( + TableIdentifier("table_name"), + Map("a" -> "1", "b" -> "0.1", "c" -> "true"), + isView = false) + + comparePlans(parsed, expected) + } + + test("alter table: set location") { + val sql1 = "ALTER TABLE table_name SET LOCATION 'new location'" + val parsed1 = parseAndResolve(sql1) + val tableIdent = TableIdentifier("table_name", None) + val expected1 = AlterTableSetLocationCommand( + tableIdent, + None, + "new location") + comparePlans(parsed1, expected1) + } } From 1a45142b707a1f0bc805320ed088a89210162cc1 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 6 Jun 2019 14:35:43 -0700 Subject: [PATCH 49/70] Fix merge conflicts --- .../spark/sql/catalyst/parser/DDLParserSuite.scala | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 39fd8afe13ff4..d008b3c78fac3 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -30,17 +30,6 @@ import org.apache.spark.unsafe.types.UTF8String class DDLParserSuite extends AnalysisTest { import CatalystSqlParser._ -<<<<<<< HEAD - private def intercept(sqlCommand: String, messages: String*): Unit = { - val e = intercept[ParseException](parsePlan(sqlCommand)) - messages.foreach { message => - assert(e.message.contains(message)) - } - } -||||||| parent of fad827a417... [SPARK-27857][SQL] Move ALTER TABLE parsing into Catalyst - private def intercept(sqlCommand: String, messages: String*): Unit = - interceptParseException(parsePlan)(sqlCommand, messages: _*) -======= private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = { val e = intercept[ParseException] { parsePlan(sql) @@ -53,7 +42,6 @@ class DDLParserSuite extends AnalysisTest { private def intercept(sqlCommand: String, messages: String*): Unit = interceptParseException(parsePlan)(sqlCommand, messages: _*) ->>>>>>> fad827a417... 
[SPARK-27857][SQL] Move ALTER TABLE parsing into Catalyst private def parseCompare(sql: String, expected: LogicalPlan): Unit = { comparePlans(parsePlan(sql), expected, checkAnalysis = false) From 8bcc74dfff4625f10ab7e35ab25ae969395ed726 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 6 Jun 2019 14:47:36 -0700 Subject: [PATCH 50/70] Revert "Fix merge conflicts" This reverts commit 1a45142b707a1f0bc805320ed088a89210162cc1. --- .../spark/sql/catalyst/parser/DDLParserSuite.scala | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index d008b3c78fac3..39fd8afe13ff4 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -30,6 +30,17 @@ import org.apache.spark.unsafe.types.UTF8String class DDLParserSuite extends AnalysisTest { import CatalystSqlParser._ +<<<<<<< HEAD + private def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parsePlan(sqlCommand)) + messages.foreach { message => + assert(e.message.contains(message)) + } + } +||||||| parent of fad827a417... [SPARK-27857][SQL] Move ALTER TABLE parsing into Catalyst + private def intercept(sqlCommand: String, messages: String*): Unit = + interceptParseException(parsePlan)(sqlCommand, messages: _*) +======= private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = { val e = intercept[ParseException] { parsePlan(sql) @@ -42,6 +53,7 @@ class DDLParserSuite extends AnalysisTest { private def intercept(sqlCommand: String, messages: String*): Unit = interceptParseException(parsePlan)(sqlCommand, messages: _*) +>>>>>>> fad827a417... [SPARK-27857][SQL] Move ALTER TABLE parsing into Catalyst private def parseCompare(sql: String, expected: LogicalPlan): Unit = { comparePlans(parsePlan(sql), expected, checkAnalysis = false) From a3debfddc4f24ef13d3615ab9be8ab854097b679 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 6 Jun 2019 14:47:39 -0700 Subject: [PATCH 51/70] Revert "[SPARK-27857][SQL] Move ALTER TABLE parsing into Catalyst" This reverts commit e1365ba26e4fe04c8bbc891ec628195fb87d1cde. 
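For readers following this revert, the parser surface being rolled back is the one added by [SPARK-27857] earlier in this series: `AstBuilder` turns `ALTER TABLE` statements into the `*Statement` plans under `catalyst.plans.logical.sql` instead of v1 commands. Below is a minimal sketch of how those plans look when produced by `CatalystSqlParser`, adapted from the `DDLParserSuite` cases in these patches; the wrapper object, `main` method, and `println` formatting are illustrative additions and not code from the patches themselves.

```scala
import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement}

// Illustrative wrapper only; the statement classes and parser are from the patch series.
object AlterTableParsingSketch {
  def main(args: Array[String]): Unit = {
    // With the moved parsing in place, ADD COLUMN yields a Catalyst statement plan
    // that carries a multipart (possibly nested) column name.
    CatalystSqlParser.parsePlan("ALTER TABLE table_name ADD COLUMN x.y.z int COMMENT 'doc'") match {
      case AlterTableAddColumnsStatement(table, cols) =>
        println(s"ADD COLUMNS on ${table.mkString(".")}: " +
          cols.map(c => s"${c.name.mkString(".")} ${c.dataType.simpleString}").mkString(", "))
      case other =>
        println(s"Unexpected plan: $other")
    }

    // ALTER COLUMN ... TYPE yields AlterTableAlterColumnStatement with an optional
    // new data type and an optional comment.
    CatalystSqlParser.parsePlan("ALTER TABLE table_name ALTER COLUMN a.b.c TYPE bigint") match {
      case AlterTableAlterColumnStatement(table, column, dataType, comment) =>
        println(s"ALTER COLUMN ${column.mkString(".")} on ${table.mkString(".")}: " +
          s"type=$dataType comment=$comment")
      case other =>
        println(s"Unexpected plan: $other")
    }
  }
}
```
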
--- .../spark/sql/catalyst/parser/SqlBase.g4 | 98 +------- .../sql/catalyst/parser/AstBuilder.scala | 162 +------------ .../logical/sql/AlterTableStatements.scala | 78 ------- .../logical/sql/AlterViewStatements.scala | 33 --- .../logical/sql/CreateTableStatement.scala | 10 +- .../plans/logical/sql/ParsedStatement.scala | 5 - .../sql/catalyst/parser/DDLParserSuite.scala | 215 +----------------- .../spark/sql/execution/SparkSqlParser.scala | 60 ++++- .../datasources/DataSourceResolution.scala | 42 +--- .../execution/command/DDLParserSuite.scala | 63 ++++- .../command/PlanResolutionSuite.scala | 76 ------- 11 files changed, 141 insertions(+), 701 deletions(-) delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/AlterTableStatements.scala delete mode 100644 sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/AlterViewStatements.scala diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 706e96318d5a7..04fbdd2ddd15f 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -110,27 +110,14 @@ statement LIKE source=tableIdentifier locationSpec? #createTableLike | ANALYZE TABLE tableIdentifier partitionSpec? COMPUTE STATISTICS (identifier | FOR COLUMNS identifierSeq | FOR ALL COLUMNS)? #analyze - | ALTER TABLE multipartIdentifier - ADD (COLUMN | COLUMNS) - columns=qualifiedColTypeWithPositionList #addTableColumns - | ALTER TABLE multipartIdentifier - ADD (COLUMN | COLUMNS) - '(' columns=qualifiedColTypeWithPositionList ')' #addTableColumns - | ALTER TABLE multipartIdentifier - RENAME COLUMN from=qualifiedName TO to=identifier #renameTableColumn - | ALTER TABLE multipartIdentifier - DROP (COLUMN | COLUMNS) '(' columns=qualifiedNameList ')' #dropTableColumns - | ALTER TABLE multipartIdentifier - DROP (COLUMN | COLUMNS) columns=qualifiedNameList #dropTableColumns + | ALTER TABLE tableIdentifier + ADD COLUMNS '(' columns=colTypeList ')' #addTableColumns | ALTER (TABLE | VIEW) from=tableIdentifier RENAME TO to=tableIdentifier #renameTable - | ALTER (TABLE | VIEW) multipartIdentifier + | ALTER (TABLE | VIEW) tableIdentifier SET TBLPROPERTIES tablePropertyList #setTableProperties - | ALTER (TABLE | VIEW) multipartIdentifier + | ALTER (TABLE | VIEW) tableIdentifier UNSET TBLPROPERTIES (IF EXISTS)? tablePropertyList #unsetTableProperties - | ALTER TABLE multipartIdentifier - (ALTER | CHANGE) COLUMN? qualifiedName - (TYPE dataType)? (COMMENT comment=STRING)? colPosition? #alterTableColumn | ALTER TABLE tableIdentifier partitionSpec? CHANGE COLUMN? identifier colType colPosition? #changeColumn | ALTER TABLE tableIdentifier (partitionSpec)? @@ -147,8 +134,7 @@ statement DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* PURGE? #dropTablePartitions | ALTER VIEW tableIdentifier DROP (IF EXISTS)? partitionSpec (',' partitionSpec)* #dropTablePartitions - | ALTER TABLE multipartIdentifier SET locationSpec #setTableLocation - | ALTER TABLE tableIdentifier partitionSpec SET locationSpec #setPartitionLocation + | ALTER TABLE tableIdentifier partitionSpec? SET locationSpec #setTableLocation | ALTER TABLE tableIdentifier RECOVER PARTITIONS #recoverPartitions | DROP TABLE (IF EXISTS)? multipartIdentifier PURGE? #dropTable | DROP VIEW (IF EXISTS)? 
multipartIdentifier #dropView @@ -704,7 +690,7 @@ intervalValue ; colPosition - : FIRST | AFTER qualifiedName + : FIRST | AFTER identifier ; dataType @@ -714,14 +700,6 @@ dataType | identifier ('(' INTEGER_VALUE (',' INTEGER_VALUE)* ')')? #primitiveDataType ; -qualifiedColTypeWithPositionList - : qualifiedColTypeWithPosition (',' qualifiedColTypeWithPosition)* - ; - -qualifiedColTypeWithPosition - : name=qualifiedName dataType (COMMENT comment=STRING)? colPosition? - ; - colTypeList : colType (',' colType)* ; @@ -774,10 +752,6 @@ frameBound | expression boundType=(PRECEDING | FOLLOWING) ; -qualifiedNameList - : qualifiedName (',' qualifiedName)* - ; - qualifiedName : identifier ('.' identifier)* ; @@ -1279,7 +1253,6 @@ nonReserved | TRANSFORM | TRUE | TRUNCATE - | TYPE | UNARCHIVE | UNBOUNDED | UNCACHE @@ -1406,7 +1379,6 @@ RESET: 'RESET'; DATA: 'DATA'; START: 'START'; TRANSACTION: 'TRANSACTION'; -<<<<<<< HEAD COMMIT: 'COMMIT'; ROLLBACK: 'ROLLBACK'; MACRO: 'MACRO'; @@ -1418,64 +1390,6 @@ TRAILING: 'TRAILING'; IF: 'IF'; POSITION: 'POSITION'; EXTRACT: 'EXTRACT'; -||||||| parent of 5d6758c0e7... [SPARK-27857][SQL] Move ALTER TABLE parsing into Catalyst -TRANSACTIONS: 'TRANSACTIONS'; -TRANSFORM: 'TRANSFORM'; -TRUE: 'TRUE'; -TRUNCATE: 'TRUNCATE'; -UNARCHIVE: 'UNARCHIVE'; -UNBOUNDED: 'UNBOUNDED'; -UNCACHE: 'UNCACHE'; -UNION: 'UNION'; -UNIQUE: 'UNIQUE'; -UNLOCK: 'UNLOCK'; -UNSET: 'UNSET'; -USE: 'USE'; -USER: 'USER'; -USING: 'USING'; -VALUES: 'VALUES'; -VIEW: 'VIEW'; -WEEK: 'WEEK'; -WEEKS: 'WEEKS'; -WHEN: 'WHEN'; -WHERE: 'WHERE'; -WINDOW: 'WINDOW'; -WITH: 'WITH'; -YEAR: 'YEAR'; -YEARS: 'YEARS'; -//============================ -// End of the keywords list -//============================ -======= -TRANSACTIONS: 'TRANSACTIONS'; -TRANSFORM: 'TRANSFORM'; -TRUE: 'TRUE'; -TRUNCATE: 'TRUNCATE'; -TYPE: 'TYPE'; -UNARCHIVE: 'UNARCHIVE'; -UNBOUNDED: 'UNBOUNDED'; -UNCACHE: 'UNCACHE'; -UNION: 'UNION'; -UNIQUE: 'UNIQUE'; -UNLOCK: 'UNLOCK'; -UNSET: 'UNSET'; -USE: 'USE'; -USER: 'USER'; -USING: 'USING'; -VALUES: 'VALUES'; -VIEW: 'VIEW'; -WEEK: 'WEEK'; -WEEKS: 'WEEKS'; -WHEN: 'WHEN'; -WHERE: 'WHERE'; -WINDOW: 'WINDOW'; -WITH: 'WITH'; -YEAR: 'YEAR'; -YEARS: 'YEARS'; -//============================ -// End of the keywords list -//============================ ->>>>>>> 5d6758c0e7... 
[SPARK-27857][SQL] Move ALTER TABLE parsing into Catalyst EQ : '=' | '=='; NSEQ: '<=>'; diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index 4c1914841b58e..81ec2a1d9c904 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.{First, Last} import org.apache.spark.sql.catalyst.parser.SqlBaseParser._ import org.apache.spark.sql.catalyst.plans._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType} +import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement} import org.apache.spark.sql.catalyst.util.DateTimeUtils.{getZoneId, stringToDate, stringToTimestamp} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ @@ -1991,13 +1991,6 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging (multipartIdentifier, temporary, ifNotExists, ctx.EXTERNAL != null) } - /** - * Parse a qualified name to a multipart name. - */ - override def visitQualifiedName(ctx: QualifiedNameContext): Seq[String] = withOrigin(ctx) { - ctx.identifier.asScala.map(_.getText) - } - /** * Parse a list of transforms. */ @@ -2030,7 +2023,8 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging ctx.transforms.asScala.map { case identityCtx: IdentityTransformContext => - IdentityTransform(FieldReference(typedVisit[Seq[String]](identityCtx.qualifiedName))) + IdentityTransform(FieldReference( + identityCtx.qualifiedName.identifier.asScala.map(_.getText))) case applyCtx: ApplyTransformContext => val arguments = applyCtx.argument.asScala.map(visitTransformArgument) @@ -2077,8 +2071,7 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging override def visitTransformArgument(ctx: TransformArgumentContext): v2.expressions.Expression = { withOrigin(ctx) { val reference = Option(ctx.qualifiedName) - .map(typedVisit[Seq[String]]) - .map(FieldReference(_)) + .map(nameCtx => FieldReference(nameCtx.identifier.asScala.map(_.getText))) val literal = Option(ctx.constant) .map(typedVisit[Literal]) .map(lit => LiteralValue(lit.value, lit.dataType)) @@ -2176,151 +2169,4 @@ class AstBuilder(conf: SQLConf) extends SqlBaseBaseVisitor[AnyRef] with Logging visitMultipartIdentifier(ctx.multipartIdentifier()), ctx.EXISTS != null) } - - /** - * Parse new column info from ADD COLUMN into a QualifiedColType. - */ - override def visitQualifiedColTypeWithPosition( - ctx: QualifiedColTypeWithPositionContext): QualifiedColType = withOrigin(ctx) { - if (ctx.colPosition != null) { - operationNotAllowed("ALTER TABLE table ADD COLUMN ... 
FIRST | AFTER otherCol", ctx) - } - - QualifiedColType( - typedVisit[Seq[String]](ctx.name), - typedVisit[DataType](ctx.dataType), - Option(ctx.comment).map(string)) - } - - /** - * Parse a [[AlterTableAddColumnsStatement]] command. - * - * For example: - * {{{ - * ALTER TABLE table1 - * ADD COLUMNS (col_name data_type [COMMENT col_comment], ...); - * }}} - */ - override def visitAddTableColumns(ctx: AddTableColumnsContext): LogicalPlan = withOrigin(ctx) { - AlterTableAddColumnsStatement( - visitMultipartIdentifier(ctx.multipartIdentifier), - ctx.columns.qualifiedColTypeWithPosition.asScala.map(typedVisit[QualifiedColType]) - ) - } - - /** - * Parse a [[AlterTableRenameColumnStatement]] command. - * - * For example: - * {{{ - * ALTER TABLE table1 RENAME COLUMN a.b.c TO x - * }}} - */ - override def visitRenameTableColumn( - ctx: RenameTableColumnContext): LogicalPlan = withOrigin(ctx) { - AlterTableRenameColumnStatement( - visitMultipartIdentifier(ctx.multipartIdentifier), - ctx.from.identifier.asScala.map(_.getText), - ctx.to.getText) - } - - /** - * Parse a [[AlterTableAlterColumnStatement]] command. - * - * For example: - * {{{ - * ALTER TABLE table1 ALTER COLUMN a.b.c TYPE bigint - * ALTER TABLE table1 ALTER COLUMN a.b.c TYPE bigint COMMENT 'new comment' - * ALTER TABLE table1 ALTER COLUMN a.b.c COMMENT 'new comment' - * }}} - */ - override def visitAlterTableColumn( - ctx: AlterTableColumnContext): LogicalPlan = withOrigin(ctx) { - val verb = if (ctx.CHANGE != null) "CHANGE" else "ALTER" - if (ctx.colPosition != null) { - operationNotAllowed(s"ALTER TABLE table $verb COLUMN ... FIRST | AFTER otherCol", ctx) - } - - if (ctx.dataType == null && ctx.comment == null) { - operationNotAllowed(s"ALTER TABLE table $verb COLUMN requires a TYPE or a COMMENT", ctx) - } - - AlterTableAlterColumnStatement( - visitMultipartIdentifier(ctx.multipartIdentifier), - typedVisit[Seq[String]](ctx.qualifiedName), - Option(ctx.dataType).map(typedVisit[DataType]), - Option(ctx.comment).map(string)) - } - - /** - * Parse a [[AlterTableDropColumnsStatement]] command. - * - * For example: - * {{{ - * ALTER TABLE table1 DROP COLUMN a.b.c - * ALTER TABLE table1 DROP COLUMNS a.b.c, x, y - * }}} - */ - override def visitDropTableColumns( - ctx: DropTableColumnsContext): LogicalPlan = withOrigin(ctx) { - val columnsToDrop = ctx.columns.qualifiedName.asScala.map(typedVisit[Seq[String]]) - AlterTableDropColumnsStatement( - visitMultipartIdentifier(ctx.multipartIdentifier), - columnsToDrop) - } - - /** - * Parse [[AlterViewSetPropertiesStatement]] or [[AlterTableSetPropertiesStatement]] commands. - * - * For example: - * {{{ - * ALTER TABLE table SET TBLPROPERTIES ('comment' = new_comment); - * ALTER VIEW view SET TBLPROPERTIES ('comment' = new_comment); - * }}} - */ - override def visitSetTableProperties( - ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { - val identifier = visitMultipartIdentifier(ctx.multipartIdentifier) - val properties = visitPropertyKeyValues(ctx.tablePropertyList) - if (ctx.VIEW != null) { - AlterViewSetPropertiesStatement(identifier, properties) - } else { - AlterTableSetPropertiesStatement(identifier, properties) - } - } - - /** - * Parse [[AlterViewUnsetPropertiesStatement]] or [[AlterTableUnsetPropertiesStatement]] commands. 
- * - * For example: - * {{{ - * ALTER TABLE table UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); - * ALTER VIEW view UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); - * }}} - */ - override def visitUnsetTableProperties( - ctx: UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { - val identifier = visitMultipartIdentifier(ctx.multipartIdentifier) - val properties = visitPropertyKeys(ctx.tablePropertyList) - val ifExists = ctx.EXISTS != null - if (ctx.VIEW != null) { - AlterViewUnsetPropertiesStatement(identifier, properties, ifExists) - } else { - AlterTableUnsetPropertiesStatement(identifier, properties, ifExists) - } - } - - /** - * Create an [[AlterTableSetLocationStatement]] command. - * - * For example: - * {{{ - * ALTER TABLE table SET LOCATION "loc"; - * }}} - */ - override def visitSetTableLocation(ctx: SetTableLocationContext): LogicalPlan = withOrigin(ctx) { - AlterTableSetLocationStatement( - visitMultipartIdentifier(ctx.multipartIdentifier), - visitLocationSpec(ctx.locationSpec)) - } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/AlterTableStatements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/AlterTableStatements.scala deleted file mode 100644 index 9d7dec9ae0ce0..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/AlterTableStatements.scala +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans.logical.sql - -import org.apache.spark.sql.types.DataType - -/** - * Column data as parsed by ALTER TABLE ... ADD COLUMNS. - */ -case class QualifiedColType(name: Seq[String], dataType: DataType, comment: Option[String]) - -/** - * ALTER TABLE ... ADD COLUMNS command, as parsed from SQL. - */ -case class AlterTableAddColumnsStatement( - tableName: Seq[String], - columnsToAdd: Seq[QualifiedColType]) extends ParsedStatement - -/** - * ALTER TABLE ... CHANGE COLUMN command, as parsed from SQL. - */ -case class AlterTableAlterColumnStatement( - tableName: Seq[String], - column: Seq[String], - dataType: Option[DataType], - comment: Option[String]) extends ParsedStatement - -/** - * ALTER TABLE ... RENAME COLUMN command, as parsed from SQL. - */ -case class AlterTableRenameColumnStatement( - tableName: Seq[String], - column: Seq[String], - newName: String) extends ParsedStatement - -/** - * ALTER TABLE ... DROP COLUMNS command, as parsed from SQL. - */ -case class AlterTableDropColumnsStatement( - tableName: Seq[String], - columnsToDrop: Seq[Seq[String]]) extends ParsedStatement - -/** - * ALTER TABLE ... SET TBLPROPERTIES command, as parsed from SQL. 
- */ -case class AlterTableSetPropertiesStatement( - tableName: Seq[String], - properties: Map[String, String]) extends ParsedStatement - -/** - * ALTER TABLE ... UNSET TBLPROPERTIES command, as parsed from SQL. - */ -case class AlterTableUnsetPropertiesStatement( - tableName: Seq[String], - propertyKeys: Seq[String], - ifExists: Boolean) extends ParsedStatement - -/** - * ALTER TABLE ... SET LOCATION command, as parsed from SQL. - */ -case class AlterTableSetLocationStatement( - tableName: Seq[String], - location: String) extends ParsedStatement diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/AlterViewStatements.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/AlterViewStatements.scala deleted file mode 100644 index bba7f12c94e50..0000000000000 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/AlterViewStatements.scala +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.catalyst.plans.logical.sql - -/** - * ALTER VIEW ... SET TBLPROPERTIES command, as parsed from SQL. - */ -case class AlterViewSetPropertiesStatement( - viewName: Seq[String], - properties: Map[String, String]) extends ParsedStatement - -/** - * ALTER VIEW ... UNSET TBLPROPERTIES command, as parsed from SQL. - */ -case class AlterViewUnsetPropertiesStatement( - viewName: Seq[String], - propertyKeys: Seq[String], - ifExists: Boolean) extends ParsedStatement diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala index 190711303e32d..7a26e01cde830 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/CreateTableStatement.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.plans.logical.sql import org.apache.spark.sql.catalog.v2.expressions.Transform import org.apache.spark.sql.catalyst.catalog.BucketSpec +import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.types.StructType @@ -37,7 +38,12 @@ case class CreateTableStatement( options: Map[String, String], location: Option[String], comment: Option[String], - ifNotExists: Boolean) extends ParsedStatement + ifNotExists: Boolean) extends ParsedStatement { + + override def output: Seq[Attribute] = Seq.empty + + override def children: Seq[LogicalPlan] = Seq.empty +} /** * A CREATE TABLE AS SELECT command, as parsed from SQL. 
@@ -54,5 +60,7 @@ case class CreateTableAsSelectStatement( comment: Option[String], ifNotExists: Boolean) extends ParsedStatement { + override def output: Seq[Attribute] = Seq.empty + override def children: Seq[LogicalPlan] = Seq(asSelect) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala index 2942c4b1fcca5..510f2a1ba1e6d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/sql/ParsedStatement.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.catalyst.plans.logical.sql -import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan /** @@ -41,9 +40,5 @@ private[sql] abstract class ParsedStatement extends LogicalPlan { case other => other } - override def output: Seq[Attribute] = Seq.empty - - override def children: Seq[LogicalPlan] = Seq.empty - final override lazy val resolved = false } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala index 39fd8afe13ff4..35cd813ae65c5 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/DDLParserSuite.scala @@ -17,43 +17,23 @@ package org.apache.spark.sql.catalyst.parser -import java.util.Locale - import org.apache.spark.sql.catalog.v2.expressions.{ApplyTransform, BucketTransform, DaysTransform, FieldReference, HoursTransform, IdentityTransform, LiteralValue, MonthsTransform, YearsTransform} import org.apache.spark.sql.catalyst.analysis.AnalysisTest import org.apache.spark.sql.catalyst.catalog.BucketSpec import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableAlterColumnStatement, AlterTableDropColumnsStatement, AlterTableRenameColumnStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType} -import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType, TimestampType} +import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement} +import org.apache.spark.sql.types.{IntegerType, StringType, StructType, TimestampType} import org.apache.spark.unsafe.types.UTF8String class DDLParserSuite extends AnalysisTest { import CatalystSqlParser._ -<<<<<<< HEAD private def intercept(sqlCommand: String, messages: String*): Unit = { val e = intercept[ParseException](parsePlan(sqlCommand)) messages.foreach { message => assert(e.message.contains(message)) } } -||||||| parent of fad827a417... 
[SPARK-27857][SQL] Move ALTER TABLE parsing into Catalyst - private def intercept(sqlCommand: String, messages: String*): Unit = - interceptParseException(parsePlan)(sqlCommand, messages: _*) -======= - private def assertUnsupported(sql: String, containsThesePhrases: Seq[String] = Seq()): Unit = { - val e = intercept[ParseException] { - parsePlan(sql) - } - assert(e.getMessage.toLowerCase(Locale.ROOT).contains("operation not allowed")) - containsThesePhrases.foreach { p => - assert(e.getMessage.toLowerCase(Locale.ROOT).contains(p.toLowerCase(Locale.ROOT))) - } - } - - private def intercept(sqlCommand: String, messages: String*): Unit = - interceptParseException(parsePlan)(sqlCommand, messages: _*) ->>>>>>> fad827a417... [SPARK-27857][SQL] Move ALTER TABLE parsing into Catalyst private def parseCompare(sql: String, expected: LogicalPlan): Unit = { comparePlans(parsePlan(sql), expected, checkAnalysis = false) @@ -414,195 +394,4 @@ class DDLParserSuite extends AnalysisTest { parseCompare(s"DROP VIEW view", DropViewStatement(Seq("view"), ifExists = false)) parseCompare(s"DROP VIEW IF EXISTS view", DropViewStatement(Seq("view"), ifExists = true)) } - - // ALTER VIEW view_name SET TBLPROPERTIES ('comment' = new_comment); - // ALTER VIEW view_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); - test("alter view: alter view properties") { - val sql1_view = "ALTER VIEW table_name SET TBLPROPERTIES ('test' = 'test', " + - "'comment' = 'new_comment')" - val sql2_view = "ALTER VIEW table_name UNSET TBLPROPERTIES ('comment', 'test')" - val sql3_view = "ALTER VIEW table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')" - - comparePlans(parsePlan(sql1_view), - AlterViewSetPropertiesStatement( - Seq("table_name"), Map("test" -> "test", "comment" -> "new_comment"))) - comparePlans(parsePlan(sql2_view), - AlterViewUnsetPropertiesStatement( - Seq("table_name"), Seq("comment", "test"), ifExists = false)) - comparePlans(parsePlan(sql3_view), - AlterViewUnsetPropertiesStatement( - Seq("table_name"), Seq("comment", "test"), ifExists = true)) - } - - // ALTER TABLE table_name SET TBLPROPERTIES ('comment' = new_comment); - // ALTER TABLE table_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); - test("alter table: alter table properties") { - val sql1_table = "ALTER TABLE table_name SET TBLPROPERTIES ('test' = 'test', " + - "'comment' = 'new_comment')" - val sql2_table = "ALTER TABLE table_name UNSET TBLPROPERTIES ('comment', 'test')" - val sql3_table = "ALTER TABLE table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')" - - comparePlans( - parsePlan(sql1_table), - AlterTableSetPropertiesStatement( - Seq("table_name"), Map("test" -> "test", "comment" -> "new_comment"))) - comparePlans( - parsePlan(sql2_table), - AlterTableUnsetPropertiesStatement( - Seq("table_name"), Seq("comment", "test"), ifExists = false)) - comparePlans( - parsePlan(sql3_table), - AlterTableUnsetPropertiesStatement( - Seq("table_name"), Seq("comment", "test"), ifExists = true)) - } - - test("alter table: add column") { - comparePlans( - parsePlan("ALTER TABLE table_name ADD COLUMN x int"), - AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, None) - ))) - } - - test("alter table: add multiple columns") { - comparePlans( - parsePlan("ALTER TABLE table_name ADD COLUMNS x int, y string"), - AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, None), - QualifiedColType(Seq("y"), StringType, None) - ))) - } - - test("alter table: add 
column with COLUMNS") { - comparePlans( - parsePlan("ALTER TABLE table_name ADD COLUMNS x int"), - AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, None) - ))) - } - - test("alter table: add column with COLUMNS (...)") { - comparePlans( - parsePlan("ALTER TABLE table_name ADD COLUMNS (x int)"), - AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, None) - ))) - } - - test("alter table: add column with COLUMNS (...) and COMMENT") { - comparePlans( - parsePlan("ALTER TABLE table_name ADD COLUMNS (x int COMMENT 'doc')"), - AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, Some("doc")) - ))) - } - - test("alter table: add column with COMMENT") { - comparePlans( - parsePlan("ALTER TABLE table_name ADD COLUMN x int COMMENT 'doc'"), - AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x"), IntegerType, Some("doc")) - ))) - } - - test("alter table: add column with nested column name") { - comparePlans( - parsePlan("ALTER TABLE table_name ADD COLUMN x.y.z int COMMENT 'doc'"), - AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x", "y", "z"), IntegerType, Some("doc")) - ))) - } - - test("alter table: add multiple columns with nested column name") { - comparePlans( - parsePlan("ALTER TABLE table_name ADD COLUMN x.y.z int COMMENT 'doc', a.b string"), - AlterTableAddColumnsStatement(Seq("table_name"), Seq( - QualifiedColType(Seq("x", "y", "z"), IntegerType, Some("doc")), - QualifiedColType(Seq("a", "b"), StringType, None) - ))) - } - - test("alter table: add column at position (not supported)") { - assertUnsupported("ALTER TABLE table_name ADD COLUMNS name bigint COMMENT 'doc' FIRST, a.b int") - assertUnsupported("ALTER TABLE table_name ADD COLUMN name bigint COMMENT 'doc' FIRST") - assertUnsupported("ALTER TABLE table_name ADD COLUMN name string AFTER a.b") - } - - test("alter table: set location") { - val sql1 = "ALTER TABLE table_name SET LOCATION 'new location'" - val parsed1 = parsePlan(sql1) - val expected1 = AlterTableSetLocationStatement(Seq("table_name"), "new location") - comparePlans(parsed1, expected1) - } - - test("alter table: rename column") { - comparePlans( - parsePlan("ALTER TABLE table_name RENAME COLUMN a.b.c TO d"), - AlterTableRenameColumnStatement( - Seq("table_name"), - Seq("a", "b", "c"), - "d")) - } - - test("alter table: update column type using ALTER") { - comparePlans( - parsePlan("ALTER TABLE table_name ALTER COLUMN a.b.c TYPE bigint"), - AlterTableAlterColumnStatement( - Seq("table_name"), - Seq("a", "b", "c"), - Some(LongType), - None)) - } - - test("alter table: update column type") { - comparePlans( - parsePlan("ALTER TABLE table_name CHANGE COLUMN a.b.c TYPE bigint"), - AlterTableAlterColumnStatement( - Seq("table_name"), - Seq("a", "b", "c"), - Some(LongType), - None)) - } - - test("alter table: update column comment") { - comparePlans( - parsePlan("ALTER TABLE table_name CHANGE COLUMN a.b.c COMMENT 'new comment'"), - AlterTableAlterColumnStatement( - Seq("table_name"), - Seq("a", "b", "c"), - None, - Some("new comment"))) - } - - test("alter table: update column type and comment") { - comparePlans( - parsePlan("ALTER TABLE table_name CHANGE COLUMN a.b.c TYPE bigint COMMENT 'new comment'"), - AlterTableAlterColumnStatement( - Seq("table_name"), - Seq("a", "b", "c"), - Some(LongType), - Some("new comment"))) - } - - test("alter table: change column position (not supported)") { - 
assertUnsupported("ALTER TABLE table_name CHANGE COLUMN name COMMENT 'doc' FIRST") - assertUnsupported("ALTER TABLE table_name CHANGE COLUMN name TYPE INT AFTER other_col") - } - - test("alter table: drop column") { - comparePlans( - parsePlan("ALTER TABLE table_name DROP COLUMN a.b.c"), - AlterTableDropColumnsStatement(Seq("table_name"), Seq(Seq("a", "b", "c")))) - } - - test("alter table: drop multiple columns") { - val sql = "ALTER TABLE table_name DROP COLUMN x, y, a.b.c" - Seq(sql, sql.replace("COLUMN", "COLUMNS")).foreach { drop => - comparePlans( - parsePlan(drop), - AlterTableDropColumnsStatement( - Seq("table_name"), - Seq(Seq("x"), Seq("y"), Seq("a", "b", "c")))) - } - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index f33abdda47522..ac61661e83e32 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -662,6 +662,57 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ctx.VIEW != null) } + /** + * Create a [[AlterTableAddColumnsCommand]] command. + * + * For example: + * {{{ + * ALTER TABLE table1 + * ADD COLUMNS (col_name data_type [COMMENT col_comment], ...); + * }}} + */ + override def visitAddTableColumns(ctx: AddTableColumnsContext): LogicalPlan = withOrigin(ctx) { + AlterTableAddColumnsCommand( + visitTableIdentifier(ctx.tableIdentifier), + visitColTypeList(ctx.columns) + ) + } + + /** + * Create an [[AlterTableSetPropertiesCommand]] command. + * + * For example: + * {{{ + * ALTER TABLE table SET TBLPROPERTIES ('comment' = new_comment); + * ALTER VIEW view SET TBLPROPERTIES ('comment' = new_comment); + * }}} + */ + override def visitSetTableProperties( + ctx: SetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { + AlterTableSetPropertiesCommand( + visitTableIdentifier(ctx.tableIdentifier), + visitPropertyKeyValues(ctx.tablePropertyList), + ctx.VIEW != null) + } + + /** + * Create an [[AlterTableUnsetPropertiesCommand]] command. + * + * For example: + * {{{ + * ALTER TABLE table UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + * ALTER VIEW view UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + * }}} + */ + override def visitUnsetTableProperties( + ctx: UnsetTablePropertiesContext): LogicalPlan = withOrigin(ctx) { + AlterTableUnsetPropertiesCommand( + visitTableIdentifier(ctx.tableIdentifier), + visitPropertyKeys(ctx.tablePropertyList), + ctx.EXISTS != null, + ctx.VIEW != null) + } + /** * Create an [[AlterTableSerDePropertiesCommand]] command. * @@ -770,18 +821,17 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } /** - * Create an [[AlterTableSetLocationCommand]] command for a partition. 
+ * Create an [[AlterTableSetLocationCommand]] command * * For example: * {{{ - * ALTER TABLE table PARTITION spec SET LOCATION "loc"; + * ALTER TABLE table [PARTITION spec] SET LOCATION "loc"; * }}} */ - override def visitSetPartitionLocation( - ctx: SetPartitionLocationContext): LogicalPlan = withOrigin(ctx) { + override def visitSetTableLocation(ctx: SetTableLocationContext): LogicalPlan = withOrigin(ctx) { AlterTableSetLocationCommand( visitTableIdentifier(ctx.tableIdentifier), - Some(visitNonOptionalPartitionSpec(ctx.partitionSpec)), + Option(ctx.partitionSpec).map(visitNonOptionalPartitionSpec), visitLocationSpec(ctx.locationSpec)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index c5f8cf24fe7d0..7d34b6568a4fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -28,12 +28,12 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.CastSupport import org.apache.spark.sql.catalyst.catalog.{BucketSpec, CatalogTable, CatalogTableType, CatalogUtils} import org.apache.spark.sql.catalyst.plans.logical.{CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan} -import org.apache.spark.sql.catalyst.plans.logical.sql.{AlterTableAddColumnsStatement, AlterTableSetLocationStatement, AlterTableSetPropertiesStatement, AlterTableUnsetPropertiesStatement, AlterViewSetPropertiesStatement, AlterViewUnsetPropertiesStatement, CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement, QualifiedColType} +import org.apache.spark.sql.catalyst.plans.logical.sql.{CreateTableAsSelectStatement, CreateTableStatement, DropTableStatement, DropViewStatement} import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.command.{AlterTableAddColumnsCommand, AlterTableSetLocationCommand, AlterTableSetPropertiesCommand, AlterTableUnsetPropertiesCommand, DropTableCommand} +import org.apache.spark.sql.execution.command.DropTableCommand import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.TableProvider -import org.apache.spark.sql.types.{HIVE_TYPE_STRING, HiveStringType, MetadataBuilder, StructField, StructType} +import org.apache.spark.sql.types.StructType case class DataSourceResolution( conf: SQLConf, @@ -96,26 +96,6 @@ case class DataSourceResolution( case DropViewStatement(AsTableIdentifier(tableName), ifExists) => DropTableCommand(tableName, ifExists, isView = true, purge = false) - - case AlterTableSetPropertiesStatement(AsTableIdentifier(table), properties) => - AlterTableSetPropertiesCommand(table, properties, isView = false) - - case AlterViewSetPropertiesStatement(AsTableIdentifier(table), properties) => - AlterTableSetPropertiesCommand(table, properties, isView = true) - - case AlterTableUnsetPropertiesStatement(AsTableIdentifier(table), propertyKeys, ifExists) => - AlterTableUnsetPropertiesCommand(table, propertyKeys, ifExists, isView = false) - - case AlterViewUnsetPropertiesStatement(AsTableIdentifier(table), propertyKeys, ifExists) => - AlterTableUnsetPropertiesCommand(table, propertyKeys, ifExists, isView = true) - - case AlterTableSetLocationStatement(AsTableIdentifier(table), newLocation) => - AlterTableSetLocationCommand(table, None, newLocation) - - case 
AlterTableAddColumnsStatement(AsTableIdentifier(table), newColumns) - if newColumns.forall(_.name.size == 1) => - // only top-level adds are supported using AlterTableAddColumnsCommand - AlterTableAddColumnsCommand(table, newColumns.map(convertToStructField)) } object V1WriteProvider { @@ -251,20 +231,4 @@ case class DataSourceResolution( tableProperties.toMap } - - private def convertToStructField(col: QualifiedColType): StructField = { - val builder = new MetadataBuilder - col.comment.foreach(builder.putString("comment", _)) - - val cleanedDataType = HiveStringType.replaceCharType(col.dataType) - if (col.dataType != cleanedDataType) { - builder.putString(HIVE_TYPE_STRING, col.dataType.catalogString) - } - - StructField( - col.name.head, - cleanedDataType, - nullable = true, - builder.build()) - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index 8203c900329c5..0dd11c1e518e0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -511,6 +511,45 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { assert(plan.newName == TableIdentifier("tbl2", Some("db1"))) } + // ALTER TABLE table_name SET TBLPROPERTIES ('comment' = new_comment); + // ALTER TABLE table_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + // ALTER VIEW view_name SET TBLPROPERTIES ('comment' = new_comment); + // ALTER VIEW view_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); + test("alter table/view: alter table/view properties") { + val sql1_table = "ALTER TABLE table_name SET TBLPROPERTIES ('test' = 'test', " + + "'comment' = 'new_comment')" + val sql2_table = "ALTER TABLE table_name UNSET TBLPROPERTIES ('comment', 'test')" + val sql3_table = "ALTER TABLE table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')" + val sql1_view = sql1_table.replace("TABLE", "VIEW") + val sql2_view = sql2_table.replace("TABLE", "VIEW") + val sql3_view = sql3_table.replace("TABLE", "VIEW") + + val parsed1_table = parser.parsePlan(sql1_table) + val parsed2_table = parser.parsePlan(sql2_table) + val parsed3_table = parser.parsePlan(sql3_table) + val parsed1_view = parser.parsePlan(sql1_view) + val parsed2_view = parser.parsePlan(sql2_view) + val parsed3_view = parser.parsePlan(sql3_view) + + val tableIdent = TableIdentifier("table_name", None) + val expected1_table = AlterTableSetPropertiesCommand( + tableIdent, Map("test" -> "test", "comment" -> "new_comment"), isView = false) + val expected2_table = AlterTableUnsetPropertiesCommand( + tableIdent, Seq("comment", "test"), ifExists = false, isView = false) + val expected3_table = AlterTableUnsetPropertiesCommand( + tableIdent, Seq("comment", "test"), ifExists = true, isView = false) + val expected1_view = expected1_table.copy(isView = true) + val expected2_view = expected2_table.copy(isView = true) + val expected3_view = expected3_table.copy(isView = true) + + comparePlans(parsed1_table, expected1_table) + comparePlans(parsed2_table, expected2_table) + comparePlans(parsed3_table, expected3_table) + comparePlans(parsed1_view, expected1_view) + comparePlans(parsed2_view, expected2_view) + comparePlans(parsed3_view, expected3_view) + } + test("alter table - property values must be set") { assertUnsupported( sql = "ALTER TABLE my_tab SET TBLPROPERTIES('key_without_value', 'key_with_value'='x')", 
@@ -708,15 +747,22 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { "SET FILEFORMAT PARQUET") } - test("alter table: set partition location") { + test("alter table: set location") { + val sql1 = "ALTER TABLE table_name SET LOCATION 'new location'" val sql2 = "ALTER TABLE table_name PARTITION (dt='2008-08-08', country='us') " + "SET LOCATION 'new location'" + val parsed1 = parser.parsePlan(sql1) val parsed2 = parser.parsePlan(sql2) val tableIdent = TableIdentifier("table_name", None) + val expected1 = AlterTableSetLocationCommand( + tableIdent, + None, + "new location") val expected2 = AlterTableSetLocationCommand( tableIdent, Some(Map("dt" -> "2008-08-08", "country" -> "us")), "new location") + comparePlans(parsed1, expected1) comparePlans(parsed2, expected2) } @@ -900,6 +946,21 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { comparePlans(parsed, expected) } + test("support for other types in TBLPROPERTIES") { + val sql = + """ + |ALTER TABLE table_name + |SET TBLPROPERTIES ('a' = 1, 'b' = 0.1, 'c' = TRUE) + """.stripMargin + val parsed = parser.parsePlan(sql) + val expected = AlterTableSetPropertiesCommand( + TableIdentifier("table_name"), + Map("a" -> "1", "b" -> "0.1", "c" -> "true"), + isView = false) + + comparePlans(parsed, expected) + } + test("Test CTAS #1") { val s1 = """ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index a834932110896..60801910c6dbc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -502,80 +502,4 @@ class PlanResolutionSuite extends AnalysisTest { }.getMessage.toLowerCase(Locale.ROOT).contains( "view support in catalog has not been implemented") } - - // ALTER VIEW view_name SET TBLPROPERTIES ('comment' = new_comment); - // ALTER VIEW view_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); - test("alter view: alter view properties") { - val sql1_view = "ALTER VIEW table_name SET TBLPROPERTIES ('test' = 'test', " + - "'comment' = 'new_comment')" - val sql2_view = "ALTER VIEW table_name UNSET TBLPROPERTIES ('comment', 'test')" - val sql3_view = "ALTER VIEW table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')" - - val parsed1_view = parseAndResolve(sql1_view) - val parsed2_view = parseAndResolve(sql2_view) - val parsed3_view = parseAndResolve(sql3_view) - - val tableIdent = TableIdentifier("table_name", None) - val expected1_view = AlterTableSetPropertiesCommand( - tableIdent, Map("test" -> "test", "comment" -> "new_comment"), isView = true) - val expected2_view = AlterTableUnsetPropertiesCommand( - tableIdent, Seq("comment", "test"), ifExists = false, isView = true) - val expected3_view = AlterTableUnsetPropertiesCommand( - tableIdent, Seq("comment", "test"), ifExists = true, isView = true) - - comparePlans(parsed1_view, expected1_view) - comparePlans(parsed2_view, expected2_view) - comparePlans(parsed3_view, expected3_view) - } - - // ALTER TABLE table_name SET TBLPROPERTIES ('comment' = new_comment); - // ALTER TABLE table_name UNSET TBLPROPERTIES [IF EXISTS] ('comment', 'key'); - test("alter table: alter table properties") { - val sql1_table = "ALTER TABLE table_name SET TBLPROPERTIES ('test' = 'test', " + - "'comment' = 'new_comment')" - val sql2_table = "ALTER TABLE table_name UNSET TBLPROPERTIES ('comment', 'test')" - 
val sql3_table = "ALTER TABLE table_name UNSET TBLPROPERTIES IF EXISTS ('comment', 'test')" - - val parsed1_table = parseAndResolve(sql1_table) - val parsed2_table = parseAndResolve(sql2_table) - val parsed3_table = parseAndResolve(sql3_table) - - val tableIdent = TableIdentifier("table_name", None) - val expected1_table = AlterTableSetPropertiesCommand( - tableIdent, Map("test" -> "test", "comment" -> "new_comment"), isView = false) - val expected2_table = AlterTableUnsetPropertiesCommand( - tableIdent, Seq("comment", "test"), ifExists = false, isView = false) - val expected3_table = AlterTableUnsetPropertiesCommand( - tableIdent, Seq("comment", "test"), ifExists = true, isView = false) - - comparePlans(parsed1_table, expected1_table) - comparePlans(parsed2_table, expected2_table) - comparePlans(parsed3_table, expected3_table) - } - - test("support for other types in TBLPROPERTIES") { - val sql = - """ - |ALTER TABLE table_name - |SET TBLPROPERTIES ('a' = 1, 'b' = 0.1, 'c' = TRUE) - """.stripMargin - val parsed = parseAndResolve(sql) - val expected = AlterTableSetPropertiesCommand( - TableIdentifier("table_name"), - Map("a" -> "1", "b" -> "0.1", "c" -> "true"), - isView = false) - - comparePlans(parsed, expected) - } - - test("alter table: set location") { - val sql1 = "ALTER TABLE table_name SET LOCATION 'new location'" - val parsed1 = parseAndResolve(sql1) - val tableIdent = TableIdentifier("table_name", None) - val expected1 = AlterTableSetLocationCommand( - tableIdent, - None, - "new location") - comparePlans(parsed1, expected1) - } } From 7c1eb922c614cd6186e0e8aab2f00da5941284b4 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 6 Jun 2019 14:47:40 -0700 Subject: [PATCH 52/70] Revert "[SPARK-27103][SQL][MINOR] List SparkSql reserved keywords in alphabet order" This reverts commit d9e0cca491783a7af86d3f008f5e595d3d9e6cd0. --- .../spark/sql/catalyst/parser/SqlBase.g4 | 501 ++---------------- .../parser/TableIdentifierParserSuite.scala | 288 ++-------- 2 files changed, 71 insertions(+), 718 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 04fbdd2ddd15f..4133331c7fc40 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -788,485 +788,62 @@ number // // Let's say you add a new token `NEWTOKEN` and this is not reserved regardless of a `spark.sql.parser.ansi.enabled` // value. In this case, you must add a token `NEWTOKEN` in both `ansiNonReserved` and `nonReserved`. -// -// It is recommended to list them in alphabetical order. // The list of the reserved keywords when `spark.sql.parser.ansi.enabled` is true. Currently, we only reserve // the ANSI keywords that almost all the ANSI SQL standards (SQL-92, SQL-99, SQL-2003, SQL-2008, SQL-2011, // and SQL-2016) and PostgreSQL reserve. 
ansiReserved - : ALL - | AND - | ANTI - | ANY - | AS - | AUTHORIZATION - | BOTH - | CASE - | CAST - | CHECK - | COLLATE - | COLUMN - | CONSTRAINT - | CREATE - | CROSS - | CURRENT_DATE - | CURRENT_TIME - | CURRENT_TIMESTAMP - | CURRENT_USER - | DISTINCT - | ELSE - | END - | EXCEPT - | FALSE - | FETCH - | FOR - | FOREIGN - | FROM - | FULL - | GRANT - | GROUP - | HAVING - | IN - | INNER - | INTERSECT - | INTO - | IS - | JOIN - | LEADING - | LEFT - | NATURAL - | NOT - | NULL - | ON - | ONLY - | OR - | ORDER - | OUTER - | OVERLAPS - | PRIMARY - | REFERENCES - | RIGHT - | SELECT - | SEMI - | SESSION_USER - | SETMINUS - | SOME - | TABLE - | THEN - | TO - | TRAILING - | UNION - | UNIQUE - | USER - | USING - | WHEN - | WHERE - | WITH + : ALL | AND | ANTI | ANY | AS | AUTHORIZATION | BOTH | CASE | CAST | CHECK | COLLATE | COLUMN | CONSTRAINT | CREATE + | CROSS | CURRENT_DATE | CURRENT_TIME | CURRENT_TIMESTAMP | CURRENT_USER | DISTINCT | ELSE | END | EXCEPT | FALSE + | FETCH | FOR | FOREIGN | FROM | FULL | GRANT | GROUP | HAVING | IN | INNER | INTERSECT | INTO | JOIN | IS + | LEADING | LEFT | NATURAL | NOT | NULL | ON | ONLY | OR | ORDER | OUTER | OVERLAPS | PRIMARY | REFERENCES | RIGHT + | SELECT | SEMI | SESSION_USER | SETMINUS | SOME | TABLE | THEN | TO | TRAILING | UNION | UNIQUE | USER | USING + | WHEN | WHERE | WITH ; // The list of the non-reserved keywords when `spark.sql.parser.ansi.enabled` is true. ansiNonReserved - : ADD - | AFTER - | ALTER - | ANALYZE - | ARCHIVE - | ARRAY - | ASC - | AT - | BETWEEN - | BUCKET - | BUCKETS - | BY - | CACHE - | CASCADE - | CHANGE - | CLEAR - | CLUSTER - | CLUSTERED - | CODEGEN - | COLLECTION - | COLUMNS - | COMMENT - | COMMIT - | COMPACT - | COMPACTIONS - | COMPUTE - | CONCATENATE - | COST - | CUBE - | CURRENT - | DATA - | DATABASE - | DATABASES - | DBPROPERTIES - | DEFINED - | DELETE - | DELIMITED - | DESC - | DESCRIBE - | DFS - | DIRECTORIES - | DIRECTORY - | DISTRIBUTE - | DIV - | DROP - | ESCAPED - | EXCHANGE - | EXISTS - | EXPLAIN - | EXPORT - | EXTENDED - | EXTERNAL - | EXTRACT - | FIELDS - | FILEFORMAT - | FIRST - | FOLLOWING - | FORMAT - | FORMATTED - | FUNCTION - | FUNCTIONS - | GLOBAL - | GROUPING - | IF - | IGNORE - | IMPORT - | INDEX - | INDEXES - | INPATH - | INPUTFORMAT - | INSERT - | INTERVAL - | ITEMS - | KEYS - | LAST - | LATERAL - | LAZY - | LIKE - | LIMIT - | LINES - | LIST - | LOAD - | LOCAL - | LOCATION - | LOCK - | LOCKS - | LOGICAL - | MACRO - | MAP - | MSCK - | NO - | NULLS - | OF - | OPTION - | OPTIONS - | OUT - | OUTPUTFORMAT - | OVER - | OVERWRITE - | PARTITION - | PARTITIONED - | PARTITIONS - | PERCENT - | PERCENTLIT - | PIVOT - | PRECEDING - | PRINCIPALS - | PURGE - | QUERY - | RANGE - | RECORDREADER - | RECORDWRITER - | RECOVER - | REDUCE - | REFRESH - | RENAME - | REPAIR - | REPLACE - | RESET - | RESTRICT - | REVOKE - | RLIKE - | ROLE - | ROLES - | ROLLBACK - | ROLLUP - | ROW - | ROWS - | SCHEMA - | SEPARATED - | SERDE - | SERDEPROPERTIES - | SET - | SETS - | SHOW - | SKEWED - | SORT - | SORTED - | START - | STATISTICS - | STORED - | STRATIFY - | STRUCT - | TABLES - | TABLESAMPLE - | TBLPROPERTIES - | TEMPORARY - | TERMINATED - | TOUCH - | TRANSACTION - | TRANSACTIONS - | TRANSFORM - | TRUE - | TRUNCATE - | UNARCHIVE - | UNBOUNDED - | UNCACHE - | UNLOCK - | UNSET - | USE - | VALUES - | VIEW - | WINDOW + : ADD | AFTER | ALTER | ANALYZE | ARCHIVE | ARRAY | ASC | AT | BETWEEN | BUCKET | BUCKETS | BY | CACHE | CASCADE + | CHANGE | CLEAR | CLUSTER | CLUSTERED | CODEGEN | COLLECTION | COLUMNS | COMMENT | COMMIT | COMPACT | 
COMPACTIONS + | COMPUTE | CONCATENATE | COST | CUBE | CURRENT | DATA | DATABASE | DATABASES | DBPROPERTIES | DEFINED | DELETE + | DELIMITED | DESC | DESCRIBE | DFS | DIRECTORIES | DIRECTORY | DISTRIBUTE | DIV | DROP | ESCAPED | EXCHANGE + | EXISTS | EXPLAIN | EXPORT | EXTENDED | EXTERNAL | EXTRACT | FIELDS | FILEFORMAT | FIRST | FOLLOWING | FORMAT + | FORMATTED | FUNCTION | FUNCTIONS | GLOBAL | GROUPING | IF | IGNORE | IMPORT | INDEX | INDEXES | INPATH + | INPUTFORMAT | INSERT | INTERVAL | ITEMS | KEYS | LAST | LATERAL | LAZY | LIKE | LIMIT | LINES | LIST | LOAD + | LOCAL | LOCATION | LOCK | LOCKS | LOGICAL | MACRO | MAP | MSCK | NO | NULLS | OF | OPTION | OPTIONS | OUT + | OUTPUTFORMAT | OVER | OVERWRITE | PARTITION | PARTITIONED | PARTITIONS | PERCENT | PERCENTLIT | PIVOT | PRECEDING + | PRINCIPALS | PURGE | QUERY | RANGE | RECORDREADER | RECORDWRITER | RECOVER | REDUCE | REFRESH | RENAME | REPAIR | REPLACE + | RESET | RESTRICT | REVOKE | RLIKE | ROLE | ROLES | ROLLBACK | ROLLUP | ROW | ROWS | SCHEMA | SEPARATED | SERDE + | SERDEPROPERTIES | SET | SETS | SHOW | SKEWED | SORT | SORTED | START | STATISTICS | STORED | STRATIFY | STRUCT + | TABLES | TABLESAMPLE | TBLPROPERTIES | TEMPORARY | TERMINATED | TOUCH | TRANSACTION | TRANSACTIONS | TRANSFORM + | TRUE | TRUNCATE | UNARCHIVE | UNBOUNDED | UNCACHE | UNLOCK | UNSET | USE | VALUES | VIEW | WINDOW ; defaultReserved - : ANTI - | CROSS - | EXCEPT - | FULL - | INNER - | INTERSECT - | JOIN - | LEFT - | NATURAL - | ON - | RIGHT - | SEMI - | SETMINUS - | UNION + : ANTI | CROSS | EXCEPT | FULL | INNER | INTERSECT | JOIN | LEFT | NATURAL | ON | RIGHT | SEMI | SETMINUS | UNION | USING ; nonReserved - : ADD - | AFTER - | ALL - | ALTER - | ANALYZE - | AND - | ANY - | ARCHIVE - | ARRAY - | AS - | ASC - | AT - | AUTHORIZATION - | BETWEEN - | BOTH - | BUCKET - | BUCKETS - | BY - | CACHE - | CASCADE - | CASE - | CAST - | CHANGE - | CHECK - | CLEAR - | CLUSTER - | CLUSTERED - | CODEGEN - | COLLATE - | COLLECTION - | COLUMN - | COLUMNS - | COMMENT - | COMMIT - | COMPACT - | COMPACTIONS - | COMPUTE - | CONCATENATE - | CONSTRAINT - | COST - | CREATE - | CUBE - | CURRENT - | CURRENT_DATE - | CURRENT_TIME - | CURRENT_TIMESTAMP - | CURRENT_USER - | DATA - | DATABASE - | DATABASES - | DBPROPERTIES - | DEFINED - | DELETE - | DELIMITED - | DESC - | DESCRIBE - | DFS - | DIRECTORIES - | DIRECTORY - | DISTINCT - | DISTRIBUTE - | DIV - | DROP - | ELSE - | END - | ESCAPED - | EXCHANGE - | EXISTS - | EXPLAIN - | EXPORT - | EXTENDED - | EXTERNAL - | EXTRACT - | FALSE - | FETCH - | FIELDS - | FILEFORMAT - | FIRST - | FOLLOWING - | FOR - | FOREIGN - | FORMAT - | FORMATTED - | FROM - | FUNCTION - | FUNCTIONS - | GLOBAL - | GRANT - | GROUP - | GROUPING - | HAVING - | IF - | IGNORE - | IMPORT - | IN - | INDEX - | INDEXES - | INPATH - | INPUTFORMAT - | INSERT - | INTERVAL - | INTO - | IS - | ITEMS - | KEYS - | LAST - | LATERAL - | LAZY - | LEADING - | LIKE - | LIMIT - | LINES - | LIST - | LOAD - | LOCAL - | LOCATION - | LOCK - | LOCKS - | LOGICAL - | MACRO - | MAP - | MSCK - | NO - | NOT - | NULL - | NULLS - | OF - | ONLY - | OPTION - | OPTIONS - | OR - | ORDER - | OUT - | OUTER - | OUTPUTFORMAT - | OVER - | OVERLAPS - | OVERWRITE - | PARTITION - | PARTITIONED - | PARTITIONS - | PERCENTLIT - | PIVOT - | POSITION - | PRECEDING - | PRIMARY - | PRINCIPALS - | PURGE - | QUERY - | RANGE - | RECORDREADER - | RECORDWRITER - | RECOVER - | REDUCE - | REFERENCES - | REFRESH - | RENAME - | REPAIR - | REPLACE - | RESET - | RESTRICT - | REVOKE - | RLIKE - | ROLE - | ROLES - | ROLLBACK - | 
ROLLUP - | ROW - | ROWS - | SELECT - | SEPARATED - | SERDE - | SERDEPROPERTIES - | SESSION_USER - | SET - | SETS - | SHOW - | SKEWED - | SOME - | SORT - | SORTED - | START - | STATISTICS - | STORED - | STRATIFY - | STRUCT - | TABLE - | TABLES - | TABLESAMPLE - | TBLPROPERTIES - | TEMPORARY - | TERMINATED - | THEN - | TO - | TOUCH - | TRAILING - | TRANSACTION - | TRANSACTIONS - | TRANSFORM - | TRUE - | TRUNCATE - | UNARCHIVE - | UNBOUNDED - | UNCACHE - | UNIQUE - | UNLOCK - | UNSET - | USE - | USER - | VALUES - | VIEW - | WHEN - | WHERE - | WINDOW - | WITH + : ADD | AFTER | ALL | ALTER | ANALYZE | AND | ANY | ARCHIVE | ARRAY | AS | ASC | AT | AUTHORIZATION | BETWEEN + | BOTH | BUCKET | BUCKETS | BY | CACHE | CASCADE | CASE | CAST | CHANGE | CHECK | CLEAR | CLUSTER | CLUSTERED + | CODEGEN | COLLATE | COLLECTION | COLUMN | COLUMNS | COMMENT | COMMIT | COMPACT | COMPACTIONS | COMPUTE + | CONCATENATE | CONSTRAINT | COST | CREATE | CUBE | CURRENT | CURRENT_DATE | CURRENT_TIME | CURRENT_TIMESTAMP + | CURRENT_USER | DATA | DATABASE | DATABASES | DBPROPERTIES | DEFINED | DELETE | DELIMITED | DESC | DESCRIBE | DFS + | DIRECTORIES | DIRECTORY | DISTINCT | DISTRIBUTE | DIV | DROP | ELSE | END | ESCAPED | EXCHANGE | EXISTS | EXPLAIN + | EXPORT | EXTENDED | EXTERNAL | EXTRACT | FALSE | FETCH | FIELDS | FILEFORMAT | FIRST | FOLLOWING | FOR | FOREIGN + | FORMAT | FORMATTED | FROM | FUNCTION | FUNCTIONS | GLOBAL | GRANT | GROUP | GROUPING | HAVING | IF | IGNORE + | IMPORT | IN | INDEX | INDEXES | INPATH | INPUTFORMAT | INSERT | INTERVAL | INTO | IS | ITEMS | KEYS | LAST + | LATERAL | LAZY | LEADING | LIKE | LIMIT | LINES | LIST | LOAD | LOCAL | LOCATION | LOCK | LOCKS | LOGICAL | MACRO + | MAP | MSCK | NO | NOT | NULL | NULLS | OF | ONLY | OPTION | OPTIONS | OR | ORDER | OUT | OUTER | OUTPUTFORMAT + | OVER | OVERLAPS | OVERWRITE | PARTITION | PARTITIONED | PARTITIONS | PERCENTLIT | PIVOT | POSITION | PRECEDING + | PRIMARY | PRINCIPALS | PURGE | QUERY | RANGE | RECORDREADER | RECORDWRITER | RECOVER | REDUCE | REFERENCES | REFRESH + | RENAME | REPAIR | REPLACE | RESET | RESTRICT | REVOKE | RLIKE | ROLE | ROLES | ROLLBACK | ROLLUP | ROW | ROWS + | SELECT | SEPARATED | SERDE | SERDEPROPERTIES | SESSION_USER | SET | SETS | SHOW | SKEWED | SOME | SORT | SORTED + | START | STATISTICS | STORED | STRATIFY | STRUCT | TABLE | TABLES | TABLESAMPLE | TBLPROPERTIES | TEMPORARY + | TERMINATED | THEN | TO | TOUCH | TRAILING | TRANSACTION | TRANSACTIONS | TRANSFORM | TRUE | TRUNCATE | UNARCHIVE + | UNBOUNDED | UNCACHE | UNLOCK | UNIQUE | UNSET | USE | USER | VALUES | VIEW | WHEN | WHERE | WINDOW | WITH ; SELECT: 'SELECT'; diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala index 3d41c27f217d9..489b7f328f8fa 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/TableIdentifierParserSuite.scala @@ -23,263 +23,39 @@ class TableIdentifierParserSuite extends SparkFunSuite { import CatalystSqlParser._ // Add "$elem$", "$value$" & "$key$" - // It is recommended to list them in alphabetical order. 
- val hiveNonReservedKeyword = Array( - "add", - "admin", - "after", - "all", - "alter", - "analyze", - "any", - "archive", - "array", - "as", - "asc", - "at", - "authorization", - "autocommit", - "before", - "between", - "bigint", - "binary", - "boolean", - "both", - "bucket", - "buckets", - "by", - "cascade", - "change", - "cluster", - "clustered", - "clusterstatus", - "collection", - "columns", - "comment", - "compact", - "compactions", - "compute", - "concatenate", - "continue", - "cost", - "create", - "cube", - "current_date", - "current_timestamp", - "cursor", - "data", - "databases", - "date", - "datetime", - "day", - "dbproperties", - "decimal", - "deferred", - "defined", - "delete", - "delimited", - "dependency", - "desc", - "describe", - "directories", - "directory", - "disable", - "distribute", - "double", - "drop", - "enable", - "escaped", - "exclusive", - "exists", - "explain", - "export", - "external", - "extract", - "false", - "fetch", - "fields", - "file", - "fileformat", - "first", - "float", - "for", - "format", - "formatted", - "functions", - "grant", - "group", - "grouping", - "hold_ddltime", - "hour", - "idxproperties", - "ignore", - "import", - "in", - "index", - "indexes", - "inpath", - "inputdriver", - "inputformat", - "insert", - "int", - "into", - "is", - "isolation", - "items", - "jar", - "key_type", - "keys", - "last", - "lateral", - "leading", - "level", - "like", - "limit", - "lines", - "load", - "local", - "location", - "lock", - "locks", - "logical", - "long", - "mapjoin", - "materialized", - "metadata", - "minus", - "minute", - "month", - "msck", - "no_drop", - "none", - "noscan", - "null", - "nulls", - "of", - "offline", - "offset", - "option", - "order", - "out", - "outer", - "outputdriver", - "outputformat", - "overwrite", - "owner", - "partition", - "partitioned", - "partitions", - "percent", - "pivot", - "plus", - "position", - "pretty", - "principals", - "procedure", - "protection", - "purge", - "query", - "range", - "read", - "readonly", - "reads", - "rebuild", - "recordreader", - "recordwriter", - "regexp", - "reload", - "rename", - "repair", - "replace", - "replication", - "restrict", - "revoke", - "rewrite", - "rlike", - "role", - "roles", - "rollup", - "row", - "rows", - "schemas", - "second", - "serde", - "serdeproperties", - "server", - "set", - "sets", - "shared", - "show", - "show_database", - "skewed", - "smallint", - "snapshot", - "sort", - "sorted", - "ssl", - "statistics", - "stored", - "streamtable", - "string", - "struct", - "table", - "tables", - "tblproperties", - "temporary", - "terminated", - "timestamp", - "tinyint", - "to", - "touch", - "trailing", - "transaction", - "transactions", - "trigger", - "true", - "truncate", - "unarchive", - "undo", - "uniontype", - "unlock", - "unset", - "unsigned", - "update", - "uri", - "use", - "user", - "utc", - "utctimestamp", - "values", - "view", - "while", - "with", - "work", - "write", - "year") + val hiveNonReservedKeyword = Array("add", "admin", "after", "analyze", "archive", "asc", "before", + "bucket", "buckets", "cascade", "change", "cluster", "clustered", "clusterstatus", "collection", + "columns", "comment", "compact", "compactions", "compute", "concatenate", "continue", "cost", + "data", "day", "databases", "datetime", "dbproperties", "deferred", "defined", "delimited", + "dependency", "desc", "directories", "directory", "disable", "distribute", + "enable", "escaped", "exclusive", "explain", "export", "fields", "file", "fileformat", "first", + "format", "formatted", "functions", 
"hold_ddltime", "hour", "idxproperties", "ignore", "index", + "indexes", "inpath", "inputdriver", "inputformat", "items", "jar", "keys", "key_type", "last", + "limit", "offset", "lines", "load", "location", "lock", "locks", "logical", "long", "mapjoin", + "materialized", "metadata", "minus", "minute", "month", "msck", "noscan", "no_drop", "nulls", + "offline", "option", "outputdriver", "outputformat", "overwrite", "owner", "partitioned", + "partitions", "plus", "pretty", "principals", "protection", "purge", "read", "readonly", + "rebuild", "recordreader", "recordwriter", "reload", "rename", "repair", "replace", + "replication", "restrict", "rewrite", "role", "roles", "schemas", "second", + "serde", "serdeproperties", "server", "sets", "shared", "show", "show_database", "skewed", + "sort", "sorted", "ssl", "statistics", "stored", "streamtable", "string", "struct", "tables", + "tblproperties", "temporary", "terminated", "tinyint", "touch", "transactions", "unarchive", + "undo", "uniontype", "unlock", "unset", "unsigned", "uri", "use", "utc", "utctimestamp", + "view", "while", "year", "work", "transaction", "write", "isolation", "level", "snapshot", + "autocommit", "all", "any", "alter", "array", "as", "authorization", "between", "bigint", + "binary", "boolean", "both", "by", "create", "cube", "current_date", "current_timestamp", + "cursor", "date", "decimal", "delete", "describe", "double", "drop", "exists", "external", + "false", "fetch", "float", "for", "grant", "group", "grouping", "import", "in", + "insert", "int", "into", "is", "pivot", "lateral", "like", "local", "none", "null", + "of", "order", "out", "outer", "partition", "percent", "procedure", "query", "range", "reads", + "revoke", "rollup", "row", "rows", "set", "smallint", "table", "timestamp", "to", "trigger", + "true", "truncate", "update", "user", "values", "with", "regexp", "rlike", + "bigint", "binary", "boolean", "current_date", "current_timestamp", "date", "double", "float", + "int", "smallint", "timestamp", "at", "position", "both", "leading", "trailing", "extract") - val hiveStrictNonReservedKeyword = Seq( - "anti", - "cross", - "database", - "except", - "from", - "full", - "having", - "inner", - "intersect", - "join", - "left", - "natural", - "not", - "on", - "right", - "select", - "semi", - "table", - "to", - "union", - "where", - "with") + val hiveStrictNonReservedKeyword = Seq("anti", "full", "inner", "left", "semi", "right", + "natural", "union", "intersect", "except", "database", "on", "join", "cross", "select", "from", + "where", "having", "from", "to", "table", "with", "not") test("table identifier") { // Regular names. From cfa37b026a4f2a313ec7113e3ab529261fda4a4e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Sun, 12 May 2019 19:59:56 +0900 Subject: [PATCH 53/70] [SPARK-27675][SQL] do not use MutableColumnarRow in ColumnarBatch ## What changes were proposed in this pull request? To move DS v2 API to the catalyst module, we can't refer to an internal class (`MutableColumnarRow`) in `ColumnarBatch`. This PR creates a read-only version of `MutableColumnarRow`, and use it in `ColumnarBatch`. close https://github.com/apache/spark/pull/24546 ## How was this patch tested? existing tests Closes #24581 from cloud-fan/mutable-row. 
Authored-by: Wenchen Fan Signed-off-by: HyukjinKwon --- .../spark/sql/vectorized/ColumnarBatch.java | 175 +++++++++++++++++- .../vectorized/MutableColumnarRow.java | 44 ++--- 2 files changed, 189 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index 07546a54013ec..9f917ea11d72a 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -20,7 +20,10 @@ import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.catalyst.InternalRow; -import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; +import org.apache.spark.sql.types.*; +import org.apache.spark.unsafe.types.CalendarInterval; +import org.apache.spark.unsafe.types.UTF8String; /** * This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this @@ -33,7 +36,7 @@ public final class ColumnarBatch { private final ColumnVector[] columns; // Staging row returned from `getRow`. - private final MutableColumnarRow row; + private final ColumnarBatchRow row; /** * Called to close all the columns in this batch. It is not valid to access the data after @@ -50,7 +53,7 @@ public void close() { */ public Iterator rowIterator() { final int maxRows = numRows; - final MutableColumnarRow row = new MutableColumnarRow(columns); + final ColumnarBatchRow row = new ColumnarBatchRow(columns); return new Iterator() { int rowId = 0; @@ -108,6 +111,170 @@ public InternalRow getRow(int rowId) { public ColumnarBatch(ColumnVector[] columns) { this.columns = columns; - this.row = new MutableColumnarRow(columns); + this.row = new ColumnarBatchRow(columns); } } + +/** + * An internal class, which wraps an array of {@link ColumnVector} and provides a row view. + */ +class ColumnarBatchRow extends InternalRow { + public int rowId; + private final ColumnVector[] columns; + + ColumnarBatchRow(ColumnVector[] columns) { + this.columns = columns; + } + + @Override + public int numFields() { return columns.length; } + + @Override + public InternalRow copy() { + GenericInternalRow row = new GenericInternalRow(columns.length); + for (int i = 0; i < numFields(); i++) { + if (isNullAt(i)) { + row.setNullAt(i); + } else { + DataType dt = columns[i].dataType(); + if (dt instanceof BooleanType) { + row.setBoolean(i, getBoolean(i)); + } else if (dt instanceof ByteType) { + row.setByte(i, getByte(i)); + } else if (dt instanceof ShortType) { + row.setShort(i, getShort(i)); + } else if (dt instanceof IntegerType) { + row.setInt(i, getInt(i)); + } else if (dt instanceof LongType) { + row.setLong(i, getLong(i)); + } else if (dt instanceof FloatType) { + row.setFloat(i, getFloat(i)); + } else if (dt instanceof DoubleType) { + row.setDouble(i, getDouble(i)); + } else if (dt instanceof StringType) { + row.update(i, getUTF8String(i).copy()); + } else if (dt instanceof BinaryType) { + row.update(i, getBinary(i)); + } else if (dt instanceof DecimalType) { + DecimalType t = (DecimalType)dt; + row.setDecimal(i, getDecimal(i, t.precision(), t.scale()), t.precision()); + } else if (dt instanceof DateType) { + row.setInt(i, getInt(i)); + } else if (dt instanceof TimestampType) { + row.setLong(i, getLong(i)); + } else { + throw new RuntimeException("Not implemented. 
" + dt); + } + } + } + return row; + } + + @Override + public boolean anyNull() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isNullAt(int ordinal) { return columns[ordinal].isNullAt(rowId); } + + @Override + public boolean getBoolean(int ordinal) { return columns[ordinal].getBoolean(rowId); } + + @Override + public byte getByte(int ordinal) { return columns[ordinal].getByte(rowId); } + + @Override + public short getShort(int ordinal) { return columns[ordinal].getShort(rowId); } + + @Override + public int getInt(int ordinal) { return columns[ordinal].getInt(rowId); } + + @Override + public long getLong(int ordinal) { return columns[ordinal].getLong(rowId); } + + @Override + public float getFloat(int ordinal) { return columns[ordinal].getFloat(rowId); } + + @Override + public double getDouble(int ordinal) { return columns[ordinal].getDouble(rowId); } + + @Override + public Decimal getDecimal(int ordinal, int precision, int scale) { + return columns[ordinal].getDecimal(rowId, precision, scale); + } + + @Override + public UTF8String getUTF8String(int ordinal) { + return columns[ordinal].getUTF8String(rowId); + } + + @Override + public byte[] getBinary(int ordinal) { + return columns[ordinal].getBinary(rowId); + } + + @Override + public CalendarInterval getInterval(int ordinal) { + return columns[ordinal].getInterval(rowId); + } + + @Override + public ColumnarRow getStruct(int ordinal, int numFields) { + return columns[ordinal].getStruct(rowId); + } + + @Override + public ColumnarArray getArray(int ordinal) { + return columns[ordinal].getArray(rowId); + } + + @Override + public ColumnarMap getMap(int ordinal) { + return columns[ordinal].getMap(rowId); + } + + @Override + public Object get(int ordinal, DataType dataType) { + if (dataType instanceof BooleanType) { + return getBoolean(ordinal); + } else if (dataType instanceof ByteType) { + return getByte(ordinal); + } else if (dataType instanceof ShortType) { + return getShort(ordinal); + } else if (dataType instanceof IntegerType) { + return getInt(ordinal); + } else if (dataType instanceof LongType) { + return getLong(ordinal); + } else if (dataType instanceof FloatType) { + return getFloat(ordinal); + } else if (dataType instanceof DoubleType) { + return getDouble(ordinal); + } else if (dataType instanceof StringType) { + return getUTF8String(ordinal); + } else if (dataType instanceof BinaryType) { + return getBinary(ordinal); + } else if (dataType instanceof DecimalType) { + DecimalType t = (DecimalType) dataType; + return getDecimal(ordinal, t.precision(), t.scale()); + } else if (dataType instanceof DateType) { + return getInt(ordinal); + } else if (dataType instanceof TimestampType) { + return getLong(ordinal); + } else if (dataType instanceof ArrayType) { + return getArray(ordinal); + } else if (dataType instanceof StructType) { + return getStruct(ordinal, ((StructType)dataType).fields().length); + } else if (dataType instanceof MapType) { + return getMap(ordinal); + } else { + throw new UnsupportedOperationException("Datatype not supported " + dataType); + } + } + + @Override + public void update(int ordinal, Object value) { throw new UnsupportedOperationException(); } + + @Override + public void setNullAt(int ordinal) { throw new UnsupportedOperationException(); } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 4e4242fe8d9b9..fca7e36859126 100644 
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -26,7 +26,6 @@ import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.sql.vectorized.ColumnarMap; import org.apache.spark.sql.vectorized.ColumnarRow; -import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -39,17 +38,10 @@ */ public final class MutableColumnarRow extends InternalRow { public int rowId; - private final ColumnVector[] columns; - private final WritableColumnVector[] writableColumns; - - public MutableColumnarRow(ColumnVector[] columns) { - this.columns = columns; - this.writableColumns = null; - } + private final WritableColumnVector[] columns; public MutableColumnarRow(WritableColumnVector[] writableColumns) { this.columns = writableColumns; - this.writableColumns = writableColumns; } @Override @@ -228,54 +220,54 @@ public void update(int ordinal, Object value) { @Override public void setNullAt(int ordinal) { - writableColumns[ordinal].putNull(rowId); + columns[ordinal].putNull(rowId); } @Override public void setBoolean(int ordinal, boolean value) { - writableColumns[ordinal].putNotNull(rowId); - writableColumns[ordinal].putBoolean(rowId, value); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putBoolean(rowId, value); } @Override public void setByte(int ordinal, byte value) { - writableColumns[ordinal].putNotNull(rowId); - writableColumns[ordinal].putByte(rowId, value); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putByte(rowId, value); } @Override public void setShort(int ordinal, short value) { - writableColumns[ordinal].putNotNull(rowId); - writableColumns[ordinal].putShort(rowId, value); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putShort(rowId, value); } @Override public void setInt(int ordinal, int value) { - writableColumns[ordinal].putNotNull(rowId); - writableColumns[ordinal].putInt(rowId, value); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putInt(rowId, value); } @Override public void setLong(int ordinal, long value) { - writableColumns[ordinal].putNotNull(rowId); - writableColumns[ordinal].putLong(rowId, value); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putLong(rowId, value); } @Override public void setFloat(int ordinal, float value) { - writableColumns[ordinal].putNotNull(rowId); - writableColumns[ordinal].putFloat(rowId, value); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putFloat(rowId, value); } @Override public void setDouble(int ordinal, double value) { - writableColumns[ordinal].putNotNull(rowId); - writableColumns[ordinal].putDouble(rowId, value); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putDouble(rowId, value); } @Override public void setDecimal(int ordinal, Decimal value, int precision) { - writableColumns[ordinal].putNotNull(rowId); - writableColumns[ordinal].putDecimal(rowId, value, precision); + columns[ordinal].putNotNull(rowId); + columns[ordinal].putDecimal(rowId, value, precision); } } From d8f503ebcb835164c76145f9f7ee81772917408c Mon Sep 17 00:00:00 2001 From: Xianyang Liu Date: Thu, 28 Mar 2019 12:11:00 -0500 Subject: [PATCH 54/70] [MINOR] Move java file to java directory ## What changes were proposed in this pull request? 
move ```scala org.apache.spark.sql.execution.streaming.BaseStreamingSource org.apache.spark.sql.execution.streaming.BaseStreamingSink ``` to java directory ## How was this patch tested? Existing UT. Closes #24222 from ConeyLiu/move-scala-to-java. Authored-by: Xianyang Liu Signed-off-by: Sean Owen --- .../apache/spark/sql/execution/streaming/BaseStreamingSink.java | 0 .../apache/spark/sql/execution/streaming/BaseStreamingSource.java | 0 .../org/apache/spark/sql/execution/streaming/Offset.java | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename sql/core/src/main/{scala => java}/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java (100%) rename sql/core/src/main/{scala => java}/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java (100%) rename sql/core/src/main/{scala => java}/org/apache/spark/sql/execution/streaming/Offset.java (100%) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java rename to sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java rename to sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/Offset.java similarity index 100% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Offset.java rename to sql/core/src/main/java/org/apache/spark/sql/execution/streaming/Offset.java From b28de534aa764620e03597f25faa59881c9bb014 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Fri, 26 Apr 2019 15:44:23 +0800 Subject: [PATCH 55/70] [SPARK-27190][SQL] add table capability for streaming This is a followup of https://github.com/apache/spark/pull/24012 , to add the corresponding capabilities for streaming. existing tests Closes #24129 from cloud-fan/capability. 
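
A hedged illustration of what this capability model looks like from a source author's point of view: the class below is hypothetical, the import locations are a best guess for this branch, and the scan wiring is deliberately elided rather than implemented.

```scala
// Hypothetical sketch, modeled on the KafkaTable hunk in the diff below: a DSv2 table now
// advertises streaming support through capabilities() instead of implementing the removed
// SupportsMicroBatchRead / SupportsContinuousRead marker interfaces.
import java.{util => ju}

import scala.collection.JavaConverters._

import org.apache.spark.sql.sources.v2.{SupportsRead, Table, TableCapability}
import org.apache.spark.sql.sources.v2.TableCapability._
import org.apache.spark.sql.sources.v2.reader.ScanBuilder
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

class ExampleStreamingTable(tableSchema: StructType) extends Table with SupportsRead {

  override def name(): String = "example_streaming_table"

  override def schema(): StructType = tableSchema

  // The planner consults this set before calling Scan.toMicroBatchStream / toContinuousStream.
  override def capabilities(): ju.Set[TableCapability] =
    Set[TableCapability](MICRO_BATCH_READ, CONTINUOUS_READ).asJava

  override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
    throw new UnsupportedOperationException("scan construction elided in this sketch")
}
```

Write support is declared the same way: a SupportsWrite table adds STREAMING_WRITE to the returned set, as the KafkaTable change below does.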
Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../sql/kafka010/KafkaSourceProvider.scala | 11 +- .../spark/sql/sources/v2/TableCapability.java | 19 +++ .../spark/sql/sources/v2/reader/Scan.java | 10 +- .../sources/v2/SupportsContinuousRead.java | 35 ----- .../sources/v2/SupportsMicroBatchRead.java | 35 ----- .../sources/v2/SupportsStreamingWrite.java | 34 ----- .../datasources/noop/NoopDataSource.scala | 11 +- .../v2/V2StreamingScanSupportCheck.scala | 64 +++++++++ .../streaming/MicroBatchExecution.scala | 61 ++++---- .../execution/streaming/StreamExecution.scala | 4 +- .../sql/execution/streaming/console.scala | 9 +- .../continuous/ContinuousExecution.scala | 22 +-- .../sql/execution/streaming/memory.scala | 9 +- .../sources/ForeachWriterTable.scala | 12 +- .../sources/RateStreamProvider.scala | 9 +- .../sources/TextSocketSourceProvider.scala | 9 +- .../streaming/sources/memoryV2.scala | 10 +- .../internal/BaseSessionStateBuilder.scala | 3 +- .../sql/streaming/DataStreamReader.scala | 5 +- .../sql/streaming/DataStreamWriter.scala | 7 +- .../sql/streaming/StreamingQueryManager.scala | 6 +- .../v2/V2StreamingScanSupportCheckSuite.scala | 130 ++++++++++++++++++ .../sources/StreamingDataSourceV2Suite.scala | 106 ++++++++------ .../sql/hive/HiveSessionStateBuilder.scala | 3 +- 24 files changed, 394 insertions(+), 230 deletions(-) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheck.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheckSuite.scala diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 4af263e1a7f2e..88363d33525ee 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.kafka010 import java.{util => ju} -import java.util.{Collections, Locale, UUID} +import java.util.{Locale, UUID} import scala.collection.JavaConverters._ @@ -28,9 +28,10 @@ import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySe import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} -import org.apache.spark.sql.execution.streaming.{Sink, Source} +import org.apache.spark.sql.execution.streaming.{BaseStreamingSink, Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.sources.v2.reader.{Scan, ScanBuilder} import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream} import org.apache.spark.sql.sources.v2.writer.WriteBuilder @@ -352,13 +353,15 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } class KafkaTable(strategy: => ConsumerStrategy) extends Table - with SupportsMicroBatchRead with SupportsContinuousRead with SupportsStreamingWrite { + with SupportsRead 
with SupportsWrite with BaseStreamingSink { override def name(): String = s"Kafka $strategy" override def schema(): StructType = KafkaOffsetReader.kafkaSchema - override def capabilities(): ju.Set[TableCapability] = Collections.emptySet() + override def capabilities(): ju.Set[TableCapability] = { + Set(MICRO_BATCH_READ, CONTINUOUS_READ, STREAMING_WRITE).asJava + } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder { override def build(): Scan = new KafkaScan(options) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java index 33c3d647bf409..c44a12b174f4c 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/TableCapability.java @@ -33,6 +33,16 @@ public enum TableCapability { */ BATCH_READ, + /** + * Signals that the table supports reads in micro-batch streaming execution mode. + */ + MICRO_BATCH_READ, + + /** + * Signals that the table supports reads in continuous streaming execution mode. + */ + CONTINUOUS_READ, + /** * Signals that the table supports append writes in batch execution mode. *

    @@ -42,6 +52,15 @@ public enum TableCapability { */ BATCH_WRITE, + /** + * Signals that the table supports append writes in streaming execution mode. + *

    + * Tables that return this capability must support appending data and may also support additional + * write modes, like {@link #TRUNCATE}, {@link #OVERWRITE_BY_FILTER}, and + * {@link #OVERWRITE_DYNAMIC}. + */ + STREAMING_WRITE, + /** * Signals that the table can be truncated in a write operation. *

    diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java index 7633d504d36b1..ac4f38287a24d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/Scan.java @@ -21,8 +21,6 @@ import org.apache.spark.sql.sources.v2.reader.streaming.ContinuousStream; import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchStream; import org.apache.spark.sql.types.StructType; -import org.apache.spark.sql.sources.v2.SupportsContinuousRead; -import org.apache.spark.sql.sources.v2.SupportsMicroBatchRead; import org.apache.spark.sql.sources.v2.Table; import org.apache.spark.sql.sources.v2.TableCapability; @@ -74,8 +72,8 @@ default Batch toBatch() { /** * Returns the physical representation of this scan for streaming query with micro-batch mode. By * default this method throws exception, data sources must overwrite this method to provide an - * implementation, if the {@link Table} that creates this scan implements - * {@link SupportsMicroBatchRead}. + * implementation, if the {@link Table} that creates this scan returns + * {@link TableCapability#MICRO_BATCH_READ} support in its {@link Table#capabilities()}. * * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure * recovery. Data streams for the same logical source in the same query @@ -90,8 +88,8 @@ default MicroBatchStream toMicroBatchStream(String checkpointLocation) { /** * Returns the physical representation of this scan for streaming query with continuous mode. By * default this method throws exception, data sources must overwrite this method to provide an - * implementation, if the {@link Table} that creates this scan implements - * {@link SupportsContinuousRead}. + * implementation, if the {@link Table} that creates this scan returns + * {@link TableCapability#CONTINUOUS_READ} support in its {@link Table#capabilities()}. * * @param checkpointLocation a path to Hadoop FS scratch space that can be used for failure * recovery. Data streams for the same logical source in the same query diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java deleted file mode 100644 index 5cc9848d9da89..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsContinuousRead.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.reader.Scan; -import org.apache.spark.sql.sources.v2.reader.ScanBuilder; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; - -/** - * An empty mix-in interface for {@link Table}, to indicate this table supports streaming scan with - * continuous mode. - *

    - * If a {@link Table} implements this interface, the - * {@link SupportsRead#newScanBuilder(CaseInsensitiveStringMap)} must return a {@link ScanBuilder} - * that builds {@link Scan} with {@link Scan#toContinuousStream(String)} implemented. - *

    - */ -@Evolving -public interface SupportsContinuousRead extends SupportsRead { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java deleted file mode 100644 index c98f3f1aa5cba..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsMicroBatchRead.java +++ /dev/null @@ -1,35 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.sources.v2.reader.Scan; -import org.apache.spark.sql.sources.v2.reader.ScanBuilder; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; - -/** - * An empty mix-in interface for {@link Table}, to indicate this table supports streaming scan with - * micro-batch mode. - *

    - * If a {@link Table} implements this interface, the - * {@link SupportsRead#newScanBuilder(CaseInsensitiveStringMap)} must return a {@link ScanBuilder} - * that builds {@link Scan} with {@link Scan#toMicroBatchStream(String)} implemented. - *

    - */ -@Evolving -public interface SupportsMicroBatchRead extends SupportsRead { } diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java deleted file mode 100644 index ac11e483c18c4..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/SupportsStreamingWrite.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.sources.v2; - -import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.execution.streaming.BaseStreamingSink; -import org.apache.spark.sql.sources.v2.writer.WriteBuilder; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; - -/** - * An empty mix-in interface for {@link Table}, to indicate this table supports streaming write. - *

    - * If a {@link Table} implements this interface, the - * {@link SupportsWrite#newWriteBuilder(CaseInsensitiveStringMap)} must return a - * {@link WriteBuilder} with {@link WriteBuilder#buildForStreaming()} implemented. - *

    - */ -@Evolving -public interface SupportsStreamingWrite extends SupportsWrite, BaseStreamingSink { } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 1da41f2baefcb..296a796890c77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -22,6 +22,7 @@ import java.util import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.streaming.BaseStreamingSink import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.writer._ @@ -38,14 +39,22 @@ class NoopDataSource extends TableProvider with DataSourceRegister { override def getTable(options: CaseInsensitiveStringMap): Table = NoopTable } -private[noop] object NoopTable extends Table with SupportsWrite with SupportsStreamingWrite { +private[noop] object NoopTable extends Table with SupportsWrite with BaseStreamingSink { override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = NoopWriteBuilder override def name(): String = "noop-table" override def schema(): StructType = new StructType() +<<<<<<< HEAD override def capabilities(): util.Set[TableCapability] = Set( TableCapability.BATCH_WRITE, TableCapability.TRUNCATE, TableCapability.ACCEPT_ANY_SCHEMA).asJava +||||||| parent of 85fd552ed6... [SPARK-27190][SQL] add table capability for streaming + override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_WRITE).asJava +======= + override def capabilities(): util.Set[TableCapability] = { + Set(TableCapability.BATCH_WRITE, TableCapability.STREAMING_WRITE).asJava + } +>>>>>>> 85fd552ed6... [SPARK-27190][SQL] add table capability for streaming } private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsTruncate { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheck.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheck.scala new file mode 100644 index 0000000000000..c029acc0bb2df --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheck.scala @@ -0,0 +1,64 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} +import org.apache.spark.sql.sources.v2.TableCapability.{CONTINUOUS_READ, MICRO_BATCH_READ} + +/** + * This rules adds some basic table capability check for streaming scan, without knowing the actual + * streaming execution mode. + */ +object V2StreamingScanSupportCheck extends (LogicalPlan => Unit) { + import DataSourceV2Implicits._ + + override def apply(plan: LogicalPlan): Unit = { + plan.foreach { + case r: StreamingRelationV2 if !r.table.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) => + throw new AnalysisException( + s"Table ${r.table.name()} does not support either micro-batch or continuous scan.") + case _ => + } + + val streamingSources = plan.collect { + case r: StreamingRelationV2 => r.table + } + val v1StreamingRelations = plan.collect { + case r: StreamingRelation => r + } + + if (streamingSources.length + v1StreamingRelations.length > 1) { + val allSupportsMicroBatch = streamingSources.forall(_.supports(MICRO_BATCH_READ)) + // v1 streaming data source only supports micro-batch. + val allSupportsContinuous = streamingSources.forall(_.supports(CONTINUOUS_READ)) && + v1StreamingRelations.isEmpty + if (!allSupportsMicroBatch && !allSupportsContinuous) { + val microBatchSources = + streamingSources.filter(_.supports(MICRO_BATCH_READ)).map(_.name()) ++ + v1StreamingRelations.map(_.sourceName) + val continuousSources = streamingSources.filter(_.supports(CONTINUOUS_READ)).map(_.name()) + throw new AnalysisException( + "The streaming sources in a query do not have a common supported execution mode.\n" + + "Sources support micro-batch: " + microBatchSources.mkString(", ") + "\n" + + "Sources support continuous: " + continuousSources.mkString(", ")) + } + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index fdd80ccaf052e..d9fe836b1c494 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -17,7 +17,6 @@ package org.apache.spark.sql.execution.streaming -import scala.collection.JavaConverters._ import scala.collection.mutable.{Map => MutableMap} import org.apache.spark.sql.{Dataset, SparkSession} @@ -78,6 +77,7 @@ class MicroBatchExecution( val disabledSources = sparkSession.sqlContext.conf.disabledV2StreamingMicroBatchReaders.split(",") + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ val _logicalPlan = analyzedPlan.transform { case streamingRelation@StreamingRelation(dataSourceV1, sourceName, output) => toExecutionRelationMap.getOrElseUpdate(streamingRelation, { @@ -88,31 +88,33 @@ class MicroBatchExecution( logInfo(s"Using Source [$source] from DataSourceV1 named '$sourceName' [$dataSourceV1]") StreamingExecutionRelation(source, output)(sparkSession) }) - case s @ StreamingRelationV2(ds, dsName, table: SupportsMicroBatchRead, options, output, _) - if !disabledSources.contains(ds.getClass.getCanonicalName) => - v2ToRelationMap.getOrElseUpdate(s, { - // Materialize source to avoid creating it in every batch - val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - nextSourceId += 
1 - logInfo(s"Reading table [$table] from DataSourceV2 named '$dsName' [$ds]") - // TODO: operator pushdown. - val scan = table.newScanBuilder(options).build() - val stream = scan.toMicroBatchStream(metadataPath) - StreamingDataSourceV2Relation(output, scan, stream) - }) - case s @ StreamingRelationV2(ds, dsName, _, _, output, v1Relation) => - v2ToExecutionRelationMap.getOrElseUpdate(s, { - // Materialize source to avoid creating it in every batch - val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" - if (v1Relation.isEmpty) { - throw new UnsupportedOperationException( - s"Data source $dsName does not support microbatch processing.") - } - val source = v1Relation.get.dataSource.createSource(metadataPath) - nextSourceId += 1 - logInfo(s"Using Source [$source] from DataSourceV2 named '$dsName' [$ds]") - StreamingExecutionRelation(source, output)(sparkSession) - }) + + case s @ StreamingRelationV2(src, srcName, table: SupportsRead, options, output, v1) => + val v2Disabled = disabledSources.contains(src.getClass.getCanonicalName) + if (!v2Disabled && table.supports(TableCapability.MICRO_BATCH_READ)) { + v2ToRelationMap.getOrElseUpdate(s, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" + nextSourceId += 1 + logInfo(s"Reading table [$table] from DataSourceV2 named '$srcName' [$src]") + // TODO: operator pushdown. + val scan = table.newScanBuilder(options).build() + val stream = scan.toMicroBatchStream(metadataPath) + StreamingDataSourceV2Relation(output, scan, stream) + }) + } else if (v1.isEmpty) { + throw new UnsupportedOperationException( + s"Data source $srcName does not support microbatch processing.") + } else { + v2ToExecutionRelationMap.getOrElseUpdate(s, { + // Materialize source to avoid creating it in every batch + val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" + val source = v1.get.dataSource.createSource(metadataPath) + nextSourceId += 1 + logInfo(s"Using Source [$source] from DataSourceV2 named '$srcName' [$src]") + StreamingExecutionRelation(source, output)(sparkSession) + }) + } } sources = _logicalPlan.collect { // v1 source @@ -122,8 +124,9 @@ class MicroBatchExecution( } uniqueSources = sources.distinct + // TODO (SPARK-27484): we should add the writing node before the plan is analyzed. sink match { - case s: SupportsStreamingWrite => + case s: SupportsWrite => val streamingWrite = createStreamingWrite(s, extraOptions, _logicalPlan) WriteToMicroBatchDataSource(streamingWrite, _logicalPlan) @@ -519,7 +522,7 @@ class MicroBatchExecution( val triggerLogicalPlan = sink match { case _: Sink => newAttributePlan - case _: SupportsStreamingWrite => + case _: SupportsWrite => newAttributePlan.asInstanceOf[WriteToMicroBatchDataSource].createPlan(currentBatchId) case _ => throw new IllegalArgumentException(s"unknown sink type for $sink") } @@ -550,7 +553,7 @@ class MicroBatchExecution( SQLExecution.withNewExecutionId(sparkSessionToRunBatch, lastExecution) { sink match { case s: Sink => s.addBatch(currentBatchId, nextBatch) - case _: SupportsStreamingWrite => + case _: SupportsWrite => // This doesn't accumulate any data - it just forces execution of the microbatch writer. 
nextBatch.collect() } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index cc441937ce70c..fd959619650e4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.execution.datasources.v2.StreamWriterCommitProgress import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.SupportsStreamingWrite +import org.apache.spark.sql.sources.v2.SupportsWrite import org.apache.spark.sql.sources.v2.writer.SupportsTruncate import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.streaming._ @@ -582,7 +582,7 @@ abstract class StreamExecution( } protected def createStreamingWrite( - table: SupportsStreamingWrite, + table: SupportsWrite, options: Map[String, String], inputPlan: LogicalPlan): StreamingWrite = { val writeBuilder = table.newWriteBuilder(new CaseInsensitiveStringMap(options.asJava)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index 884b92ae9421c..c7161d311c028 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.streaming import java.util -import java.util.Collections + +import scala.collection.JavaConverters._ import org.apache.spark.sql._ import org.apache.spark.sql.execution.streaming.sources.ConsoleWrite @@ -60,13 +61,15 @@ class ConsoleSinkProvider extends TableProvider def shortName(): String = "console" } -object ConsoleTable extends Table with SupportsStreamingWrite { +object ConsoleTable extends Table with SupportsWrite with BaseStreamingSink { override def name(): String = "console" override def schema(): StructType = StructType(Nil) - override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def capabilities(): util.Set[TableCapability] = { + Set(TableCapability.STREAMING_WRITE).asJava + } override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new WriteBuilder with SupportsTruncate { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index c8fb53df52598..ef0c942e959ea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.execution.SQLExecution import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming.{StreamingRelationV2, _} import org.apache.spark.sql.sources.v2 -import org.apache.spark.sql.sources.v2.{SupportsContinuousRead, SupportsStreamingWrite} +import org.apache.spark.sql.sources.v2.{SupportsRead, SupportsWrite, TableCapability} import 
org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, PartitionOffset} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.Clock @@ -42,14 +42,14 @@ class ContinuousExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: SupportsStreamingWrite, + sink: SupportsWrite, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, extraOptions: Map[String, String], deleteCheckpointOnStop: Boolean) extends StreamExecution( - sparkSession, name, checkpointRoot, analyzedPlan, sink, + sparkSession, name, checkpointRoot, analyzedPlan, sink.asInstanceOf[BaseStreamingSink], trigger, triggerClock, outputMode, deleteCheckpointOnStop) { @volatile protected var sources: Seq[ContinuousStream] = Seq() @@ -63,22 +63,23 @@ class ContinuousExecution( override val logicalPlan: WriteToContinuousDataSource = { val v2ToRelationMap = MutableMap[StreamingRelationV2, StreamingDataSourceV2Relation]() var nextSourceId = 0 + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ val _logicalPlan = analyzedPlan.transform { - case s @ StreamingRelationV2( - ds, dsName, table: SupportsContinuousRead, options, output, _) => + case s @ StreamingRelationV2(ds, sourceName, table: SupportsRead, options, output, _) => + if (!table.supports(TableCapability.CONTINUOUS_READ)) { + throw new UnsupportedOperationException( + s"Data source $sourceName does not support continuous processing.") + } + v2ToRelationMap.getOrElseUpdate(s, { val metadataPath = s"$resolvedCheckpointRoot/sources/$nextSourceId" nextSourceId += 1 - logInfo(s"Reading table [$table] from DataSourceV2 named '$dsName' [$ds]") + logInfo(s"Reading table [$table] from DataSourceV2 named '$sourceName' [$ds]") // TODO: operator pushdown. val scan = table.newScanBuilder(options).build() val stream = scan.toContinuousStream(metadataPath) StreamingDataSourceV2Relation(output, scan, stream) }) - - case StreamingRelationV2(_, sourceName, _, _, _, _) => - throw new UnsupportedOperationException( - s"Data source $sourceName does not support continuous processing.") } sources = _logicalPlan.collect { @@ -86,6 +87,7 @@ class ContinuousExecution( } uniqueSources = sources.distinct + // TODO (SPARK-27484): we should add the writing node before the plan is analyzed. 
WriteToContinuousDataSource( createStreamingWrite(sink, extraOptions, _logicalPlan), _logicalPlan) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index bfa9c09985503..0dcbdd3a1fd21 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -18,10 +18,10 @@ package org.apache.spark.sql.execution.streaming import java.util -import java.util.Collections import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable.{ArrayBuffer, ListBuffer} import scala.util.control.NonFatal @@ -92,14 +92,15 @@ object MemoryStreamTableProvider extends TableProvider { } } -class MemoryStreamTable(val stream: MemoryStreamBase[_]) extends Table - with SupportsMicroBatchRead with SupportsContinuousRead { +class MemoryStreamTable(val stream: MemoryStreamBase[_]) extends Table with SupportsRead { override def name(): String = "MemoryStreamDataSource" override def schema(): StructType = stream.fullSchema() - override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def capabilities(): util.Set[TableCapability] = { + Set(TableCapability.MICRO_BATCH_READ, TableCapability.CONTINUOUS_READ).asJava + } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { new MemoryStreamScanBuilder(stream) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala index 807e0b12c6278..838ede6c563f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala @@ -18,14 +18,16 @@ package org.apache.spark.sql.execution.streaming.sources import java.util -import java.util.Collections + +import scala.collection.JavaConverters._ import org.apache.spark.sql.{ForeachWriter, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.python.PythonForeachWriter -import org.apache.spark.sql.sources.v2.{SupportsStreamingWrite, Table, TableCapability} +import org.apache.spark.sql.execution.streaming.BaseStreamingSink +import org.apache.spark.sql.sources.v2.{SupportsWrite, Table, TableCapability} import org.apache.spark.sql.sources.v2.writer.{DataWriter, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType @@ -42,13 +44,15 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ForeachWriterTable[T]( writer: ForeachWriter[T], converter: Either[ExpressionEncoder[T], InternalRow => T]) - extends Table with SupportsStreamingWrite { + extends Table with SupportsWrite with BaseStreamingSink { override def name(): String = "ForeachSink" override def schema(): StructType = StructType(Nil) - override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def capabilities(): util.Set[TableCapability] = { + 
Set(TableCapability.STREAMING_WRITE).asJava + } override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new WriteBuilder with SupportsTruncate { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala index 08aea75de2b5a..8dbae9f787cf0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProvider.scala @@ -18,7 +18,8 @@ package org.apache.spark.sql.execution.streaming.sources import java.util -import java.util.Collections + +import scala.collection.JavaConverters._ import org.apache.spark.network.util.JavaUtils import org.apache.spark.sql.SparkSession @@ -78,7 +79,7 @@ class RateStreamTable( rowsPerSecond: Long, rampUpTimeSeconds: Long, numPartitions: Int) - extends Table with SupportsMicroBatchRead with SupportsContinuousRead { + extends Table with SupportsRead { override def name(): String = { s"RateStream(rowsPerSecond=$rowsPerSecond, rampUpTimeSeconds=$rampUpTimeSeconds, " + @@ -87,7 +88,9 @@ class RateStreamTable( override def schema(): StructType = RateStreamProvider.SCHEMA - override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def capabilities(): util.Set[TableCapability] = { + Set(TableCapability.MICRO_BATCH_READ, TableCapability.CONTINUOUS_READ).asJava + } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder { override def build(): Scan = new Scan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala index c0292acdf1044..e714859c16ddd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketSourceProvider.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.execution.streaming.sources import java.text.SimpleDateFormat import java.util -import java.util.{Collections, Locale} +import java.util.Locale +import scala.collection.JavaConverters._ import scala.util.{Failure, Success, Try} import org.apache.spark.internal.Logging @@ -67,7 +68,7 @@ class TextSocketSourceProvider extends TableProvider with DataSourceRegister wit } class TextSocketTable(host: String, port: Int, numPartitions: Int, includeTimestamp: Boolean) - extends Table with SupportsMicroBatchRead with SupportsContinuousRead { + extends Table with SupportsRead { override def name(): String = s"Socket[$host:$port]" @@ -79,7 +80,9 @@ class TextSocketTable(host: String, port: Int, numPartitions: Int, includeTimest } } - override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def capabilities(): util.Set[TableCapability] = { + Set(TableCapability.MICRO_BATCH_READ, TableCapability.CONTINUOUS_READ).asJava + } override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new ScanBuilder { override def build(): Scan = new Scan { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala index 8eb5de0f640a4..219e25c1407b7 100644 --- 
a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala @@ -18,9 +18,9 @@ package org.apache.spark.sql.execution.streaming.sources import java.util -import java.util.Collections import javax.annotation.concurrent.GuardedBy +import scala.collection.JavaConverters._ import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} -import org.apache.spark.sql.sources.v2.{SupportsStreamingWrite, TableCapability} +import org.apache.spark.sql.sources.v2.{SupportsWrite, Table, TableCapability} import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} import org.apache.spark.sql.types.StructType @@ -43,13 +43,15 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySinkV2 extends SupportsStreamingWrite with MemorySinkBase with Logging { +class MemorySinkV2 extends Table with SupportsWrite with MemorySinkBase with Logging { override def name(): String = "MemorySinkV2" override def schema(): StructType = StructType(Nil) - override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def capabilities(): util.Set[TableCapability] = { + Set(TableCapability.STREAMING_WRITE).asJava + } override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = { new WriteBuilder with SupportsTruncate { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala index 588f5dde85930..18029abb08dab 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.{QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser} import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.v2.V2WriteSupportCheck +import org.apache.spark.sql.execution.datasources.v2.{V2StreamingScanSupportCheck, V2WriteSupportCheck} import org.apache.spark.sql.streaming.StreamingQueryManager import org.apache.spark.sql.util.ExecutionListenerManager @@ -175,6 +175,7 @@ abstract class BaseSessionStateBuilder( PreReadCheck +: HiveOnlyCheck +: V2WriteSupportCheck +: + V2StreamingScanSupportCheck +: customCheckRules } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala index cce8fcd3012bf..da4723e34c0d7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamReader.scala @@ -30,6 +30,7 @@ import 
org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.execution.streaming.{StreamingRelation, StreamingRelationV2} import org.apache.spark.sql.sources.StreamSourceProvider import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -181,8 +182,9 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo case Some(schema) => provider.getTable(dsOptions, schema) case _ => provider.getTable(dsOptions) } + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ table match { - case _: SupportsMicroBatchRead | _: SupportsContinuousRead => + case _: SupportsRead if table.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) => Dataset.ofRows( sparkSession, StreamingRelationV2( @@ -190,6 +192,7 @@ final class DataStreamReader private[sql](sparkSession: SparkSession) extends Lo sparkSession)) // fallback to v1 + // TODO (SPARK-27483): we should move this fallback logic to an analyzer rule. case _ => Dataset.ofRows(sparkSession, StreamingRelation(v1DataSource)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 33d032eb78c2b..d2df3a5349dd5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -31,7 +31,8 @@ import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Utils import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.sources.v2.{SupportsStreamingWrite, TableProvider} +import org.apache.spark.sql.sources.v2.{SupportsWrite, TableProvider} +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.util.CaseInsensitiveStringMap /** @@ -315,8 +316,10 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { source = provider, conf = df.sparkSession.sessionState.conf) val options = sessionOptions ++ extraOptions val dsOptions = new CaseInsensitiveStringMap(options.asJava) + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ provider.getTable(dsOptions) match { - case s: SupportsStreamingWrite => s + case table: SupportsWrite if table.supports(STREAMING_WRITE) => + table.asInstanceOf[BaseStreamingSink] case _ => createV1Sink() } } else { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 9d86cac9cec5b..040f1723fb93b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS -import org.apache.spark.sql.sources.v2.SupportsStreamingWrite +import org.apache.spark.sql.sources.v2.SupportsWrite import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -258,7 
+258,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo } (sink, trigger) match { - case (v2Sink: SupportsStreamingWrite, trigger: ContinuousTrigger) => + case (table: SupportsWrite, trigger: ContinuousTrigger) => if (operationCheckEnabled) { UnsupportedOperationChecker.checkForContinuous(analyzedPlan, outputMode) } @@ -267,7 +267,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo userSpecifiedName.orNull, checkpointLocation, analyzedPlan, - v2Sink, + table, trigger, triggerClock, outputMode, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheckSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheckSuite.scala new file mode 100644 index 0000000000000..8a0450fce76a1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/V2StreamingScanSupportCheckSuite.scala @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.datasources.v2 + +import java.util + +import scala.collection.JavaConverters._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} +import org.apache.spark.sql.catalyst.plans.logical.Union +import org.apache.spark.sql.execution.datasources.DataSource +import org.apache.spark.sql.execution.streaming.{Offset, Source, StreamingRelation, StreamingRelationV2} +import org.apache.spark.sql.sources.StreamSourceProvider +import org.apache.spark.sql.sources.v2.{Table, TableCapability, TableProvider} +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.CaseInsensitiveStringMap + +class V2StreamingScanSupportCheckSuite extends SparkFunSuite with SharedSparkSession { + import TableCapability._ + + private def createStreamingRelation(table: Table, v1Relation: Option[StreamingRelation]) = { + StreamingRelationV2(FakeTableProvider, "fake", table, CaseInsensitiveStringMap.empty(), + FakeTableProvider.schema.toAttributes, v1Relation)(spark) + } + + private def createStreamingRelationV1() = { + StreamingRelation(DataSource(spark, classOf[FakeStreamSourceProvider].getName)) + } + + test("check correct plan") { + val plan1 = createStreamingRelation(CapabilityTable(MICRO_BATCH_READ), None) + val plan2 = createStreamingRelation(CapabilityTable(CONTINUOUS_READ), None) + val plan3 = createStreamingRelation(CapabilityTable(MICRO_BATCH_READ, CONTINUOUS_READ), None) + val plan4 = createStreamingRelationV1() + + V2StreamingScanSupportCheck(Union(plan1, plan1)) + V2StreamingScanSupportCheck(Union(plan2, plan2)) + V2StreamingScanSupportCheck(Union(plan1, plan3)) + V2StreamingScanSupportCheck(Union(plan2, plan3)) + V2StreamingScanSupportCheck(Union(plan1, plan4)) + V2StreamingScanSupportCheck(Union(plan3, plan4)) + } + + test("table without scan capability") { + val e = intercept[AnalysisException] { + V2StreamingScanSupportCheck(createStreamingRelation(CapabilityTable(), None)) + } + assert(e.message.contains("does not support either micro-batch or continuous scan")) + } + + test("mix micro-batch only and continuous only") { + val plan1 = createStreamingRelation(CapabilityTable(MICRO_BATCH_READ), None) + val plan2 = createStreamingRelation(CapabilityTable(CONTINUOUS_READ), None) + + val e = intercept[AnalysisException] { + V2StreamingScanSupportCheck(Union(plan1, plan2)) + } + assert(e.message.contains( + "The streaming sources in a query do not have a common supported execution mode")) + } + + test("mix continuous only and v1 relation") { + val plan1 = createStreamingRelation(CapabilityTable(CONTINUOUS_READ), None) + val plan2 = createStreamingRelationV1() + val e = intercept[AnalysisException] { + V2StreamingScanSupportCheck(Union(plan1, plan2)) + } + assert(e.message.contains( + "The streaming sources in a query do not have a common supported execution mode")) + } +} + +private object FakeTableProvider extends TableProvider { + val schema = new StructType().add("i", "int") + + override def getTable(options: CaseInsensitiveStringMap): Table = { + throw new UnsupportedOperationException + } +} + +private case class CapabilityTable(_capabilities: TableCapability*) extends Table { + override def name(): String = "capability_test_table" + override def schema(): StructType = FakeTableProvider.schema + override def capabilities(): util.Set[TableCapability] = _capabilities.toSet.asJava +} + +private class FakeStreamSourceProvider extends 
StreamSourceProvider { + override def sourceSchema( + sqlContext: SQLContext, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): (String, StructType) = { + "fake" -> FakeTableProvider.schema + } + + override def createSource( + sqlContext: SQLContext, + metadataPath: String, + schema: Option[StructType], + providerName: String, + parameters: Map[String, String]): Source = { + new Source { + override def schema: StructType = FakeTableProvider.schema + override def getOffset: Option[Offset] = { + throw new UnsupportedOperationException + } + override def getBatch(start: Option[Offset], end: Offset): DataFrame = { + throw new UnsupportedOperationException + } + override def stop(): Unit = {} + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index f022edea275e0..25a68e4f9a57c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -20,13 +20,16 @@ package org.apache.spark.sql.streaming.sources import java.util import java.util.Collections -import org.apache.spark.sql.{DataFrame, SQLContext} +import scala.collection.JavaConverters._ + +import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper} +import org.apache.spark.sql.execution.streaming.{BaseStreamingSink, RateStreamOffset, Sink, StreamingQueryWrapper} import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} import org.apache.spark.sql.sources.v2._ +import org.apache.spark.sql.sources.v2.TableCapability._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming._ import org.apache.spark.sql.sources.v2.writer.{WriteBuilder, WriterCommitMessage} @@ -77,24 +80,12 @@ class FakeWriteBuilder extends WriteBuilder with StreamingWrite { } } -trait FakeMicroBatchReadTable extends Table with SupportsMicroBatchRead { - override def name(): String = "fake" - override def schema(): StructType = StructType(Seq()) - override def capabilities(): util.Set[TableCapability] = Collections.emptySet() - override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new FakeScanBuilder -} - -trait FakeContinuousReadTable extends Table with SupportsContinuousRead { - override def name(): String = "fake" - override def schema(): StructType = StructType(Seq()) - override def capabilities(): util.Set[TableCapability] = Collections.emptySet() - override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = new FakeScanBuilder -} - -trait FakeStreamingWriteTable extends Table with SupportsStreamingWrite { +trait FakeStreamingWriteTable extends Table with SupportsWrite with BaseStreamingSink { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) - override def capabilities(): util.Set[TableCapability] = Collections.emptySet() + override def capabilities(): util.Set[TableCapability] = { + Set(STREAMING_WRITE).asJava + } override def newWriteBuilder(options: CaseInsensitiveStringMap): 
WriteBuilder = { new FakeWriteBuilder } @@ -110,7 +101,16 @@ class FakeReadMicroBatchOnly override def getTable(options: CaseInsensitiveStringMap): Table = { LastReadOptions.options = options - new FakeMicroBatchReadTable {} + new Table with SupportsRead { + override def name(): String = "fake" + override def schema(): StructType = StructType(Seq()) + override def capabilities(): util.Set[TableCapability] = { + Set(MICRO_BATCH_READ).asJava + } + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new FakeScanBuilder + } + } } } @@ -124,7 +124,16 @@ class FakeReadContinuousOnly override def getTable(options: CaseInsensitiveStringMap): Table = { LastReadOptions.options = options - new FakeContinuousReadTable {} + new Table with SupportsRead { + override def name(): String = "fake" + override def schema(): StructType = StructType(Seq()) + override def capabilities(): util.Set[TableCapability] = { + Set(CONTINUOUS_READ).asJava + } + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new FakeScanBuilder + } + } } } @@ -132,7 +141,16 @@ class FakeReadBothModes extends DataSourceRegister with TableProvider { override def shortName(): String = "fake-read-microbatch-continuous" override def getTable(options: CaseInsensitiveStringMap): Table = { - new Table with FakeMicroBatchReadTable with FakeContinuousReadTable {} + new Table with SupportsRead { + override def name(): String = "fake" + override def schema(): StructType = StructType(Seq()) + override def capabilities(): util.Set[TableCapability] = { + Set(MICRO_BATCH_READ, CONTINUOUS_READ).asJava + } + override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = { + new FakeScanBuilder + } + } } } @@ -365,39 +383,37 @@ class StreamingDataSourceV2Suite extends StreamTest { val sinkTable = DataSource.lookupDataSource(write, spark.sqlContext.conf).getConstructor() .newInstance().asInstanceOf[TableProvider].getTable(CaseInsensitiveStringMap.empty()) - (sourceTable, sinkTable, trigger) match { - // Valid microbatch queries. - case (_: SupportsMicroBatchRead, _: SupportsStreamingWrite, t) - if !t.isInstanceOf[ContinuousTrigger] => - testPositiveCase(read, write, trigger) - - // Valid continuous queries. - case (_: SupportsContinuousRead, _: SupportsStreamingWrite, - _: ContinuousTrigger) => - testPositiveCase(read, write, trigger) - + import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ + trigger match { // Invalid - can't read at all - case (r, _, _) if !r.isInstanceOf[SupportsMicroBatchRead] && - !r.isInstanceOf[SupportsContinuousRead] => + case _ if !sourceTable.supportsAny(MICRO_BATCH_READ, CONTINUOUS_READ) => testNegativeCase(read, write, trigger, s"Data source $read does not support streamed reading") // Invalid - can't write - case (_, w, _) if !w.isInstanceOf[SupportsStreamingWrite] => + case _ if !sinkTable.supports(STREAMING_WRITE) => testNegativeCase(read, write, trigger, s"Data source $write does not support streamed writing") - // Invalid - trigger is continuous but reader is not - case (r, _: SupportsStreamingWrite, _: ContinuousTrigger) - if !r.isInstanceOf[SupportsContinuousRead] => - testNegativeCase(read, write, trigger, - s"Data source $read does not support continuous processing") + case _: ContinuousTrigger => + if (sourceTable.supports(CONTINUOUS_READ)) { + // Valid microbatch queries. 
+ testPositiveCase(read, write, trigger) + } else { + // Invalid - trigger is continuous but reader is not + testNegativeCase( + read, write, trigger, s"Data source $read does not support continuous processing") + } - // Invalid - trigger is microbatch but reader is not - case (r, _, t) if !r.isInstanceOf[SupportsMicroBatchRead] && - !t.isInstanceOf[ContinuousTrigger] => - testPostCreationNegativeCase(read, write, trigger, - s"Data source $read does not support microbatch processing") + case microBatchTrigger => + if (sourceTable.supports(MICRO_BATCH_READ)) { + // Valid continuous queries. + testPositiveCase(read, write, trigger) + } else { + // Invalid - trigger is microbatch but reader is not + testPostCreationNegativeCase(read, write, trigger, + s"Data source $read does not support microbatch processing") + } } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 23c777ea1030b..84e5fae79bf16 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution.SparkPlanner import org.apache.spark.sql.execution.datasources._ -import org.apache.spark.sql.execution.datasources.v2.V2WriteSupportCheck +import org.apache.spark.sql.execution.datasources.v2.{V2StreamingScanSupportCheck, V2WriteSupportCheck} import org.apache.spark.sql.hive.client.HiveClient import org.apache.spark.sql.internal.{BaseSessionStateBuilder, SessionResourceLoader, SessionState} @@ -89,6 +89,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session PreWriteCheck +: PreReadCheck +: V2WriteSupportCheck +: + V2StreamingScanSupportCheck +: customCheckRules } From a022526dbe0727ed9b269edfb1683059fa45878f Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 6 Jun 2019 15:00:28 -0700 Subject: [PATCH 56/70] Fix merge conflicts --- .../execution/datasources/noop/NoopDataSource.scala | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 296a796890c77..321d006986f1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -43,18 +43,11 @@ private[noop] object NoopTable extends Table with SupportsWrite with BaseStreami override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = NoopWriteBuilder override def name(): String = "noop-table" override def schema(): StructType = new StructType() -<<<<<<< HEAD override def capabilities(): util.Set[TableCapability] = Set( TableCapability.BATCH_WRITE, TableCapability.TRUNCATE, - TableCapability.ACCEPT_ANY_SCHEMA).asJava -||||||| parent of 85fd552ed6... [SPARK-27190][SQL] add table capability for streaming - override def capabilities(): util.Set[TableCapability] = Set(TableCapability.BATCH_WRITE).asJava -======= - override def capabilities(): util.Set[TableCapability] = { - Set(TableCapability.BATCH_WRITE, TableCapability.STREAMING_WRITE).asJava - } ->>>>>>> 85fd552ed6... 
[SPARK-27190][SQL] add table capability for streaming + TableCapability.ACCEPT_ANY_SCHEMA, + TableCapability.STREAMING_WRITE).asJava } private[noop] object NoopWriteBuilder extends WriteBuilder with SupportsTruncate { From 4e5087f625b12d23b64ad4b0e88d9ed16fc55991 Mon Sep 17 00:00:00 2001 From: Gabor Somogyi Date: Mon, 29 Apr 2019 09:44:23 -0700 Subject: [PATCH 57/70] [SPARK-23014][SS] Fully remove V1 memory sink. There is a MemorySink v2 already so v1 can be removed. In this PR I've removed it completely. What this PR contains: * V1 memory sink removal * V2 memory sink renamed to become the only implementation * Since DSv2 sends exceptions in a chained format (linking them with cause field) I've made python side compliant * Adapted all the tests Existing unit tests. Closes #24403 from gaborgsomogyi/SPARK-23014. Authored-by: Gabor Somogyi Signed-off-by: Marcelo Vanzin --- .../scala/org/apache/spark/TestUtils.scala | 14 +++ .../sql/kafka010/KafkaContinuousTest.scala | 1 - .../spark/ml/recommendation/ALSSuite.scala | 10 +- .../org/apache/spark/ml/util/MLTest.scala | 10 +- python/pyspark/sql/streaming.py | 2 +- python/pyspark/sql/tests/test_streaming.py | 12 ++- python/pyspark/sql/utils.py | 49 ++++++---- .../spark/sql/execution/SparkStrategies.scala | 5 +- .../sql/execution/streaming/memory.scala | 92 +------------------ .../sources/{memoryV2.scala => memory.scala} | 14 +-- .../sql/streaming/DataStreamWriter.scala | 12 +-- .../execution/streaming/MemorySinkSuite.scala | 88 ++++++++++++++---- .../streaming/MemorySinkV2Suite.scala | 66 ------------- .../streaming/EventTimeWatermarkSuite.scala | 1 + .../sql/streaming/FileStreamSourceSuite.scala | 1 + .../spark/sql/streaming/StreamSuite.scala | 16 ++-- .../spark/sql/streaming/StreamTest.scala | 16 +--- .../streaming/StreamingAggregationSuite.scala | 7 ++ .../StreamingQueryListenerSuite.scala | 2 +- .../sql/streaming/StreamingQuerySuite.scala | 22 ++--- ...ontinuousQueryStatusAndProgressSuite.scala | 2 +- .../continuous/ContinuousSuite.scala | 9 +- 22 files changed, 187 insertions(+), 264 deletions(-) rename sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/{memoryV2.scala => memory.scala} (93%) delete mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala diff --git a/core/src/main/scala/org/apache/spark/TestUtils.scala b/core/src/main/scala/org/apache/spark/TestUtils.scala index c2ebd388a2365..c97b10ee63b18 100644 --- a/core/src/main/scala/org/apache/spark/TestUtils.scala +++ b/core/src/main/scala/org/apache/spark/TestUtils.scala @@ -192,6 +192,20 @@ private[spark] object TestUtils { assert(listener.numSpilledStages == 0, s"expected $identifier to not spill, but did") } + /** + * Asserts that exception message contains the message. Please note this checks all + * exceptions in the tree. + */ + def assertExceptionMsg(exception: Throwable, msg: String): Unit = { + var e = exception + var contains = e.getMessage.contains(msg) + while (e.getCause != null && !contains) { + e = e.getCause + contains = e.getMessage.contains(msg) + } + assert(contains, s"Exception tree doesn't contain the expected message: $msg") + } + /** * Test if a command is available. 
*/ diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala index ad1c2c59d9c8e..9ee8cbfa1bef4 100644 --- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaContinuousTest.scala @@ -30,7 +30,6 @@ import org.apache.spark.sql.test.TestSparkSession // Trait to configure StreamTest for kafka continuous execution tests. trait KafkaContinuousTest extends KafkaSourceTest { override val defaultTrigger = Trigger.Continuous(1000) - override val defaultUseV2Sink = true // We need more than the default local[2] to be able to schedule all partitions simultaneously. override protected def createSparkSession = new TestSparkSession( diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index 2fc9754ecfe1e..4cc467a6d664a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -696,12 +696,14 @@ class ALSSuite extends MLTest with DefaultReadWriteTest with Logging { withClue("transform should fail when ids exceed integer range. ") { val model = als.fit(df) def testTransformIdExceedsIntRange[A : Encoder](dataFrame: DataFrame): Unit = { - assert(intercept[SparkException] { + val e1 = intercept[SparkException] { model.transform(dataFrame).first - }.getMessage.contains(msg)) - assert(intercept[StreamingQueryException] { + } + TestUtils.assertExceptionMsg(e1, msg) + val e2 = intercept[StreamingQueryException] { testTransformer[A](dataFrame, model, "prediction") { _ => } - }.getMessage.contains(msg)) + } + TestUtils.assertExceptionMsg(e2, msg) } testTransformIdExceedsIntRange[(Long, Int)](df.select(df("user_big").as("user"), df("item"))) diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala index 514fa7f2e1b8d..0861a3a2d099e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala @@ -21,7 +21,7 @@ import java.io.File import org.scalatest.Suite -import org.apache.spark.{DebugFilesystem, SparkConf, SparkContext} +import org.apache.spark.{DebugFilesystem, SparkConf, SparkContext, TestUtils} import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK import org.apache.spark.ml.{PredictionModel, Transformer} import org.apache.spark.ml.linalg.Vector @@ -129,21 +129,17 @@ trait MLTest extends StreamTest with TempDirectory { self: Suite => expectedMessagePart : String, firstResultCol: String) { - def hasExpectedMessage(exception: Throwable): Boolean = - exception.getMessage.contains(expectedMessagePart) || - (exception.getCause != null && exception.getCause.getMessage.contains(expectedMessagePart)) - withClue(s"""Expected message part "${expectedMessagePart}" is not found in DF test.""") { val exceptionOnDf = intercept[Throwable] { testTransformerOnDF(dataframe, transformer, firstResultCol)(_ => Unit) } - assert(hasExpectedMessage(exceptionOnDf)) + TestUtils.assertExceptionMsg(exceptionOnDf, expectedMessagePart) } withClue(s"""Expected message part "${expectedMessagePart}" is not found in stream test.""") { val exceptionOnStreamData = intercept[Throwable] { 
testTransformerOnStreamData(dataframe, transformer, firstResultCol)(_ => Unit) } - assert(hasExpectedMessage(exceptionOnStreamData)) + TestUtils.assertExceptionMsg(exceptionOnStreamData, expectedMessagePart) } } diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fc23b9d99c34a..6cc47ccdbd431 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -186,7 +186,7 @@ def exception(self): je = self._jsq.exception().get() msg = je.toString().split(': ', 1)[1] # Drop the Java StreamingQueryException type info stackTrace = '\n\t at '.join(map(lambda x: x.toString(), je.getStackTrace())) - return StreamingQueryException(msg, stackTrace) + return StreamingQueryException(msg, stackTrace, je.getCause()) else: return None diff --git a/python/pyspark/sql/tests/test_streaming.py b/python/pyspark/sql/tests/test_streaming.py index 4b71759f74a55..1bd81c4411202 100644 --- a/python/pyspark/sql/tests/test_streaming.py +++ b/python/pyspark/sql/tests/test_streaming.py @@ -224,11 +224,19 @@ def test_stream_exception(self): self.fail("bad udf should fail the query") except StreamingQueryException as e: # This is expected - self.assertTrue("ZeroDivisionError" in e.desc) + self._assert_exception_tree_contains_msg(e, "ZeroDivisionError") finally: sq.stop() self.assertTrue(type(sq.exception()) is StreamingQueryException) - self.assertTrue("ZeroDivisionError" in sq.exception().desc) + self._assert_exception_tree_contains_msg(sq.exception(), "ZeroDivisionError") + + def _assert_exception_tree_contains_msg(self, exception, msg): + e = exception + contains = msg in e.desc + while e.cause is not None and not contains: + e = e.cause + contains = msg in e.desc + self.assertTrue(contains, "Exception tree doesn't contain the expected message: %s" % msg) def test_query_manager_await_termination(self): df = self.spark.readStream.format('text').load('python/test_support/sql/streaming') diff --git a/python/pyspark/sql/utils.py b/python/pyspark/sql/utils.py index bdb3a1467f1d8..b80bc4822b6f6 100644 --- a/python/pyspark/sql/utils.py +++ b/python/pyspark/sql/utils.py @@ -19,9 +19,10 @@ class CapturedException(Exception): - def __init__(self, desc, stackTrace): + def __init__(self, desc, stackTrace, cause=None): self.desc = desc self.stackTrace = stackTrace + self.cause = convert_exception(cause) if cause is not None else None def __str__(self): return repr(self.desc) @@ -57,27 +58,41 @@ class QueryExecutionException(CapturedException): """ +class UnknownException(CapturedException): + """ + None of the above exceptions. 
+ """ + + +def convert_exception(e): + s = e.toString() + stackTrace = '\n\t at '.join(map(lambda x: x.toString(), e.getStackTrace())) + c = e.getCause() + if s.startswith('org.apache.spark.sql.AnalysisException: '): + return AnalysisException(s.split(': ', 1)[1], stackTrace, c) + if s.startswith('org.apache.spark.sql.catalyst.analysis'): + return AnalysisException(s.split(': ', 1)[1], stackTrace, c) + if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '): + return ParseException(s.split(': ', 1)[1], stackTrace, c) + if s.startswith('org.apache.spark.sql.streaming.StreamingQueryException: '): + return StreamingQueryException(s.split(': ', 1)[1], stackTrace, c) + if s.startswith('org.apache.spark.sql.execution.QueryExecutionException: '): + return QueryExecutionException(s.split(': ', 1)[1], stackTrace, c) + if s.startswith('java.lang.IllegalArgumentException: '): + return IllegalArgumentException(s.split(': ', 1)[1], stackTrace, c) + return UnknownException(s, stackTrace, c) + + def capture_sql_exception(f): def deco(*a, **kw): try: return f(*a, **kw) except py4j.protocol.Py4JJavaError as e: - s = e.java_exception.toString() - stackTrace = '\n\t at '.join(map(lambda x: x.toString(), - e.java_exception.getStackTrace())) - if s.startswith('org.apache.spark.sql.AnalysisException: '): - raise AnalysisException(s.split(': ', 1)[1], stackTrace) - if s.startswith('org.apache.spark.sql.catalyst.analysis'): - raise AnalysisException(s.split(': ', 1)[1], stackTrace) - if s.startswith('org.apache.spark.sql.catalyst.parser.ParseException: '): - raise ParseException(s.split(': ', 1)[1], stackTrace) - if s.startswith('org.apache.spark.sql.streaming.StreamingQueryException: '): - raise StreamingQueryException(s.split(': ', 1)[1], stackTrace) - if s.startswith('org.apache.spark.sql.execution.QueryExecutionException: '): - raise QueryExecutionException(s.split(': ', 1)[1], stackTrace) - if s.startswith('java.lang.IllegalArgumentException: '): - raise IllegalArgumentException(s.split(': ', 1)[1], stackTrace) - raise + converted = convert_exception(e.java_exception) + if not isinstance(converted, UnknownException): + raise converted + else: + raise return deco diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index edfa70403ad15..e72ddf13f1668 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide} import org.apache.spark.sql.execution.python._ import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.MemoryPlanV2 +import org.apache.spark.sql.execution.streaming.sources.MemoryPlan import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{OutputMode, StreamingQuery} import org.apache.spark.sql.types.StructType @@ -557,9 +557,6 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case r: RunnableCommand => ExecutedCommandExec(r) :: Nil case MemoryPlan(sink, output) => - val encoder = RowEncoder(sink.schema) - LocalTableScanExec(output, sink.allData.map(r => encoder.toRow(r).copy())) :: Nil - case MemoryPlanV2(sink, output) => val encoder = RowEncoder(StructType.fromAttributes(output)) LocalTableScanExec(output, 
sink.allData.map(r => encoder.toRow(r).copy())) :: Nil diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 0dcbdd3a1fd21..6efde0a27efe9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -22,23 +22,19 @@ import java.util.concurrent.atomic.AtomicInteger import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ -import scala.collection.mutable.{ArrayBuffer, ListBuffer} -import scala.util.control.NonFatal +import scala.collection.mutable.ListBuffer import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.encoderFor -import org.apache.spark.sql.catalyst.expressions.{Attribute, UnsafeRow} -import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Statistics} -import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils -import org.apache.spark.sql.catalyst.streaming.InternalOutputModes._ +import org.apache.spark.sql.catalyst.expressions.UnsafeRow +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader._ import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream, Offset => OffsetV2} -import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -276,85 +272,3 @@ trait MemorySinkBase extends BaseStreamingSink { def dataSinceBatch(sinceBatchId: Long): Seq[Row] def latestBatchId: Option[Long] } - -/** - * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit - * tests and does not provide durability. - */ -class MemorySink(val schema: StructType, outputMode: OutputMode) extends Sink - with MemorySinkBase with Logging { - - private case class AddedData(batchId: Long, data: Array[Row]) - - /** An order list of batches that have been written to this [[Sink]]. */ - @GuardedBy("this") - private val batches = new ArrayBuffer[AddedData]() - - /** Returns all rows that are stored in this [[Sink]]. 
*/ - def allData: Seq[Row] = synchronized { - batches.flatMap(_.data) - } - - def latestBatchId: Option[Long] = synchronized { - batches.lastOption.map(_.batchId) - } - - def latestBatchData: Seq[Row] = synchronized { batches.lastOption.toSeq.flatten(_.data) } - - def dataSinceBatch(sinceBatchId: Long): Seq[Row] = synchronized { - batches.filter(_.batchId > sinceBatchId).flatMap(_.data) - } - - def toDebugString: String = synchronized { - batches.map { case AddedData(batchId, data) => - val dataStr = try data.mkString(" ") catch { - case NonFatal(e) => "[Error converting to string]" - } - s"$batchId: $dataStr" - }.mkString("\n") - } - - override def addBatch(batchId: Long, data: DataFrame): Unit = { - val notCommitted = synchronized { - latestBatchId.isEmpty || batchId > latestBatchId.get - } - if (notCommitted) { - logDebug(s"Committing batch $batchId to $this") - outputMode match { - case Append | Update => - val rows = AddedData(batchId, data.collect()) - synchronized { batches += rows } - - case Complete => - val rows = AddedData(batchId, data.collect()) - synchronized { - batches.clear() - batches += rows - } - - case _ => - throw new IllegalArgumentException( - s"Output mode $outputMode is not supported by MemorySink") - } - } else { - logDebug(s"Skipping already committed batch: $batchId") - } - } - - def clear(): Unit = synchronized { - batches.clear() - } - - override def toString(): String = "MemorySink" -} - -/** - * Used to query the data that has been written into a [[MemorySink]]. - */ -case class MemoryPlan(sink: MemorySink, output: Seq[Attribute]) extends LeafNode { - def this(sink: MemorySink) = this(sink, sink.schema.toAttributes) - - private val sizePerRow = EstimationUtils.getSizePerRow(sink.schema.toAttributes) - - override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala similarity index 93% rename from sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala rename to sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala index 219e25c1407b7..9008c63491cb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memoryV2.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala @@ -43,9 +43,9 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. 
*/ -class MemorySinkV2 extends Table with SupportsWrite with MemorySinkBase with Logging { +class MemorySink extends Table with SupportsWrite with MemorySinkBase with Logging { - override def name(): String = "MemorySinkV2" + override def name(): String = "MemorySink" override def schema(): StructType = StructType(Nil) @@ -69,7 +69,7 @@ class MemorySinkV2 extends Table with SupportsWrite with MemorySinkBase with Log } override def buildForStreaming(): StreamingWrite = { - new MemoryStreamingWrite(MemorySinkV2.this, inputSchema, needTruncate) + new MemoryStreamingWrite(MemorySink.this, inputSchema, needTruncate) } } } @@ -130,14 +130,14 @@ class MemorySinkV2 extends Table with SupportsWrite with MemorySinkBase with Log batches.clear() } - override def toString(): String = "MemorySinkV2" + override def toString(): String = "MemorySink" } case class MemoryWriterCommitMessage(partition: Int, data: Seq[Row]) extends WriterCommitMessage {} class MemoryStreamingWrite( - val sink: MemorySinkV2, schema: StructType, needTruncate: Boolean) + val sink: MemorySink, schema: StructType, needTruncate: Boolean) extends StreamingWrite { override def createStreamingWriterFactory: MemoryWriterFactory = { @@ -195,9 +195,9 @@ class MemoryDataWriter(partition: Int, schema: StructType) /** - * Used to query the data that has been written into a [[MemorySinkV2]]. + * Used to query the data that has been written into a [[MemorySink]]. */ -case class MemoryPlanV2(sink: MemorySinkV2, override val output: Seq[Attribute]) extends LeafNode { +case class MemoryPlan(sink: MemorySink, override val output: Seq[Attribute]) extends LeafNode { private val sizePerRow = EstimationUtils.getSizePerRow(output) override def computeStats(): Statistics = Statistics(sizePerRow * sink.allData.size) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index d2df3a5349dd5..2f12efe04c507 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -254,16 +254,8 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { if (extraOptions.get("queryName").isEmpty) { throw new AnalysisException("queryName must be specified for memory sink") } - val (sink, resultDf) = trigger match { - case _: ContinuousTrigger => - val s = new MemorySinkV2() - val r = Dataset.ofRows(df.sparkSession, new MemoryPlanV2(s, df.schema.toAttributes)) - (s, r) - case _ => - val s = new MemorySink(df.schema, outputMode) - val r = Dataset.ofRows(df.sparkSession, new MemoryPlan(s)) - (s, r) - } + val sink = new MemorySink() + val resultDf = Dataset.ofRows(df.sparkSession, new MemoryPlan(sink, df.schema.toAttributes)) val chkpointLoc = extraOptions.get("checkpointLocation") val recoverFromChkpoint = outputMode == OutputMode.Complete() val query = df.sparkSession.sessionState.streamingQueryManager.startQuery( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index 3bc36ce55d902..3ead91fcf712a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -22,6 +22,8 @@ import scala.language.implicitConversions import org.scalatest.BeforeAndAfter import 
org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.streaming.sources._ import org.apache.spark.sql.streaming.{OutputMode, StreamTest} import org.apache.spark.sql.types.{IntegerType, StructField, StructType} import org.apache.spark.util.Utils @@ -36,7 +38,8 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Append output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append) + val sink = new MemorySink + val addBatch = addBatchFunc(sink, false) _ // Before adding data, check output assert(sink.latestBatchId === None) @@ -44,25 +47,25 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, Seq.empty) // Add batch 0 and check outputs - sink.addBatch(0, 1 to 3) + addBatch(0, 1 to 3) assert(sink.latestBatchId === Some(0)) checkAnswer(sink.latestBatchData, 1 to 3) checkAnswer(sink.allData, 1 to 3) // Add batch 1 and check outputs - sink.addBatch(1, 4 to 6) + addBatch(1, 4 to 6) assert(sink.latestBatchId === Some(1)) checkAnswer(sink.latestBatchData, 4 to 6) checkAnswer(sink.allData, 1 to 6) // new data should get appended to old data // Re-add batch 1 with different data, should not be added and outputs should not be changed - sink.addBatch(1, 7 to 9) + addBatch(1, 7 to 9) assert(sink.latestBatchId === Some(1)) checkAnswer(sink.latestBatchData, 4 to 6) checkAnswer(sink.allData, 1 to 6) // Add batch 2 and check outputs - sink.addBatch(2, 7 to 9) + addBatch(2, 7 to 9) assert(sink.latestBatchId === Some(2)) checkAnswer(sink.latestBatchData, 7 to 9) checkAnswer(sink.allData, 1 to 9) @@ -70,7 +73,8 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Update output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Update) + val sink = new MemorySink + val addBatch = addBatchFunc(sink, false) _ // Before adding data, check output assert(sink.latestBatchId === None) @@ -78,25 +82,25 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, Seq.empty) // Add batch 0 and check outputs - sink.addBatch(0, 1 to 3) + addBatch(0, 1 to 3) assert(sink.latestBatchId === Some(0)) checkAnswer(sink.latestBatchData, 1 to 3) checkAnswer(sink.allData, 1 to 3) // Add batch 1 and check outputs - sink.addBatch(1, 4 to 6) + addBatch(1, 4 to 6) assert(sink.latestBatchId === Some(1)) checkAnswer(sink.latestBatchData, 4 to 6) checkAnswer(sink.allData, 1 to 6) // new data should get appended to old data // Re-add batch 1 with different data, should not be added and outputs should not be changed - sink.addBatch(1, 7 to 9) + addBatch(1, 7 to 9) assert(sink.latestBatchId === Some(1)) checkAnswer(sink.latestBatchData, 4 to 6) checkAnswer(sink.allData, 1 to 6) // Add batch 2 and check outputs - sink.addBatch(2, 7 to 9) + addBatch(2, 7 to 9) assert(sink.latestBatchId === Some(2)) checkAnswer(sink.latestBatchData, 7 to 9) checkAnswer(sink.allData, 1 to 9) @@ -104,7 +108,8 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("directly add data in Complete output mode") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Complete) + val sink = new MemorySink + val addBatch = addBatchFunc(sink, true) _ // Before adding data, check output 
assert(sink.latestBatchId === None) @@ -112,25 +117,25 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { checkAnswer(sink.allData, Seq.empty) // Add batch 0 and check outputs - sink.addBatch(0, 1 to 3) + addBatch(0, 1 to 3) assert(sink.latestBatchId === Some(0)) checkAnswer(sink.latestBatchData, 1 to 3) checkAnswer(sink.allData, 1 to 3) // Add batch 1 and check outputs - sink.addBatch(1, 4 to 6) + addBatch(1, 4 to 6) assert(sink.latestBatchId === Some(1)) checkAnswer(sink.latestBatchData, 4 to 6) checkAnswer(sink.allData, 4 to 6) // new data should replace old data // Re-add batch 1 with different data, should not be added and outputs should not be changed - sink.addBatch(1, 7 to 9) + addBatch(1, 7 to 9) assert(sink.latestBatchId === Some(1)) checkAnswer(sink.latestBatchData, 4 to 6) checkAnswer(sink.allData, 4 to 6) // Add batch 2 and check outputs - sink.addBatch(2, 7 to 9) + addBatch(2, 7 to 9) assert(sink.latestBatchId === Some(2)) checkAnswer(sink.latestBatchData, 7 to 9) checkAnswer(sink.allData, 7 to 9) @@ -211,18 +216,19 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { test("MemoryPlan statistics") { implicit val schema = new StructType().add(new StructField("value", IntegerType)) - val sink = new MemorySink(schema, OutputMode.Append) - val plan = new MemoryPlan(sink) + val sink = new MemorySink + val addBatch = addBatchFunc(sink, false) _ + val plan = new MemoryPlan(sink, schema.toAttributes) // Before adding data, check output checkAnswer(sink.allData, Seq.empty) assert(plan.stats.sizeInBytes === 0) - sink.addBatch(0, 1 to 3) + addBatch(0, 1 to 3) plan.invalidateStatsCache() assert(plan.stats.sizeInBytes === 36) - sink.addBatch(1, 4 to 6) + addBatch(1, 4 to 6) plan.invalidateStatsCache() assert(plan.stats.sizeInBytes === 72) } @@ -285,6 +291,50 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { } } + test("data writer") { + val partition = 1234 + val writer = new MemoryDataWriter( + partition, new StructType().add("i", "int")) + writer.write(InternalRow(1)) + writer.write(InternalRow(2)) + writer.write(InternalRow(44)) + val msg = writer.commit() + assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44)) + assert(msg.partition == partition) + + // Buffer should be cleared, so repeated commits should give empty. 
+ assert(writer.commit().data.isEmpty) + } + + test("streaming writer") { + val sink = new MemorySink + val write = new MemoryStreamingWrite( + sink, new StructType().add("i", "int"), needTruncate = false) + write.commit(0, + Array( + MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), + MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), + MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) + )) + assert(sink.latestBatchId.contains(0)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) + write.commit(19, + Array( + MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), + MemoryWriterCommitMessage(0, Seq(Row(33))) + )) + assert(sink.latestBatchId.contains(19)) + assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) + + assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) + } + + private def addBatchFunc(sink: MemorySink, needTruncate: Boolean)( + batchId: Long, + vals: Seq[Int]): Unit = { + sink.write(batchId, needTruncate, vals.map(Row(_)).toArray) + } + private def checkAnswer(rows: Seq[Row], expected: Seq[Int])(implicit schema: StructType): Unit = { checkAnswer( sqlContext.createDataFrame(sparkContext.makeRDD(rows), schema), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala deleted file mode 100644 index a90acf85c0161..0000000000000 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkV2Suite.scala +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming - -import org.scalatest.BeforeAndAfter - -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.streaming.sources._ -import org.apache.spark.sql.streaming.{OutputMode, StreamTest} -import org.apache.spark.sql.types.StructType - -class MemorySinkV2Suite extends StreamTest with BeforeAndAfter { - test("data writer") { - val partition = 1234 - val writer = new MemoryDataWriter( - partition, new StructType().add("i", "int")) - writer.write(InternalRow(1)) - writer.write(InternalRow(2)) - writer.write(InternalRow(44)) - val msg = writer.commit() - assert(msg.data.map(_.getInt(0)) == Seq(1, 2, 44)) - assert(msg.partition == partition) - - // Buffer should be cleared, so repeated commits should give empty. 
- assert(writer.commit().data.isEmpty) - } - - test("streaming writer") { - val sink = new MemorySinkV2 - val write = new MemoryStreamingWrite( - sink, new StructType().add("i", "int"), needTruncate = false) - write.commit(0, - Array( - MemoryWriterCommitMessage(0, Seq(Row(1), Row(2))), - MemoryWriterCommitMessage(1, Seq(Row(3), Row(4))), - MemoryWriterCommitMessage(2, Seq(Row(6), Row(7))) - )) - assert(sink.latestBatchId.contains(0)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7)) - write.commit(19, - Array( - MemoryWriterCommitMessage(3, Seq(Row(11), Row(22))), - MemoryWriterCommitMessage(0, Seq(Row(33))) - )) - assert(sink.latestBatchId.contains(19)) - assert(sink.latestBatchData.map(_.getInt(0)).sorted == Seq(11, 22, 33)) - - assert(sink.allData.map(_.getInt(0)).sorted == Seq(1, 2, 3, 4, 6, 7, 11, 22, 33)) - } -} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index c696204cecc2c..a0a55c08ff018 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.{AnalysisException, Dataset} import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.streaming._ +import org.apache.spark.sql.execution.streaming.sources.MemorySink import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.OutputMode._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala index 9235c6d7c896f..0736c6ef00eed 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/FileStreamSourceSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.FileStreamSource.{FileEntry, SeenFilesMap} +import org.apache.spark.sql.execution.streaming.sources.MemorySink import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.ExistsThrowsExceptionFileSystem._ import org.apache.spark.sql.streaming.util.StreamManualClock diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index 659deb8cbb51e..f229b08a20aa0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -29,7 +29,7 @@ import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration import org.scalatest.time.SpanSugar._ -import org.apache.spark.{SparkConf, SparkContext, TaskContext} +import org.apache.spark.{SparkConf, SparkContext, TaskContext, TestUtils} import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.plans.logical.Range @@ -37,7 +37,7 @@ import org.apache.spark.sql.catalyst.streaming.InternalOutputModes import 
org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.command.ExplainCommand import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream +import org.apache.spark.sql.execution.streaming.sources.{ContinuousMemoryStream, MemorySink} import org.apache.spark.sql.execution.streaming.state.{StateStore, StateStoreConf, StateStoreId, StateStoreProvider} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf @@ -876,8 +876,8 @@ class StreamSuite extends StreamTest { query.awaitTermination() } - assert(e.getMessage.contains(providerClassName)) - assert(e.getMessage.contains("instantiated")) + TestUtils.assertExceptionMsg(e, providerClassName) + TestUtils.assertExceptionMsg(e, "instantiated") } } @@ -1083,15 +1083,15 @@ class StreamSuite extends StreamTest { test("SPARK-26379 Structured Streaming - Exception on adding current_timestamp " + " to Dataset - use v2 sink") { - testCurrentTimestampOnStreamingQuery(useV2Sink = true) + testCurrentTimestampOnStreamingQuery() } test("SPARK-26379 Structured Streaming - Exception on adding current_timestamp " + " to Dataset - use v1 sink") { - testCurrentTimestampOnStreamingQuery(useV2Sink = false) + testCurrentTimestampOnStreamingQuery() } - private def testCurrentTimestampOnStreamingQuery(useV2Sink: Boolean): Unit = { + private def testCurrentTimestampOnStreamingQuery(): Unit = { val input = MemoryStream[Int] val df = input.toDS().withColumn("cur_timestamp", lit(current_timestamp())) @@ -1109,7 +1109,7 @@ class StreamSuite extends StreamTest { var lastTimestamp = System.currentTimeMillis() val currentDate = DateTimeUtils.millisToDays(lastTimestamp) - testStream(df, useV2Sink = useV2Sink) ( + testStream(df) ( AddData(input, 1), CheckLastBatch { rows: Seq[Row] => lastTimestamp = assertBatchOutputAndUpdateLastTimestamp(rows, lastTimestamp, currentDate, 1) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index a8efe5b4e889e..69889638e7617 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.streaming -import java.lang.Thread.UncaughtExceptionHandler - import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.language.experimental.macros @@ -42,7 +40,7 @@ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relation import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} -import org.apache.spark.sql.execution.streaming.sources.MemorySinkV2 +import org.apache.spark.sql.execution.streaming.sources.MemorySink import org.apache.spark.sql.execution.streaming.state.StateStore import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext @@ -86,7 +84,6 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } protected val defaultTrigger = Trigger.ProcessingTime(0) - protected val defaultUseV2Sink = false /** How long to wait for an active stream to catch up when checking a result. 
*/ val streamingTimeout = 10.seconds @@ -327,8 +324,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be */ def testStream( _stream: Dataset[_], - outputMode: OutputMode = OutputMode.Append, - useV2Sink: Boolean = defaultUseV2Sink)(actions: StreamAction*): Unit = synchronized { + outputMode: OutputMode = OutputMode.Append)(actions: StreamAction*): Unit = synchronized { import org.apache.spark.sql.streaming.util.StreamManualClock // `synchronized` is added to prevent the user from calling multiple `testStream`s concurrently @@ -341,7 +337,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be var currentStream: StreamExecution = null var lastStream: StreamExecution = null val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for - val sink = if (useV2Sink) new MemorySinkV2 else new MemorySink(stream.schema, outputMode) + val sink = new MemorySink val resetConfValues = mutable.Map[String, Option[String]]() val defaultCheckpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath @@ -394,10 +390,8 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be } def testState = { - val sinkDebugString = sink match { - case s: MemorySink => s.toDebugString - case s: MemorySinkV2 => s.toDebugString - } + val sinkDebugString = sink.toDebugString + s""" |== Progress == |$testActions diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 97dbb9b0360ec..01b77f30ff922 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -32,7 +32,14 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.execution.streaming._ +<<<<<<< HEAD import org.apache.spark.sql.execution.streaming.state.{StateStore, StreamingAggregationStateManager} +||||||| parent of fb6b19ab7c... [SPARK-23014][SS] Fully remove V1 memory sink. +import org.apache.spark.sql.execution.streaming.state.StreamingAggregationStateManager +======= +import org.apache.spark.sql.execution.streaming.sources.MemorySink +import org.apache.spark.sql.execution.streaming.state.StreamingAggregationStateManager +>>>>>>> fb6b19ab7c... [SPARK-23014][SS] Fully remove V1 memory sink. 
import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index d00f2e3bf4d1a..5351d9cf7f190 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -180,7 +180,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { val listeners = (1 to 5).map(_ => new EventCollector) try { listeners.foreach(listener => spark.streams.addListener(listener)) - testStream(df, OutputMode.Append, useV2Sink = true)( + testStream(df, OutputMode.Append)( StartStream(Trigger.Continuous(1000)), StopStream, AssertOnQuery { query => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 729173cb7104f..ec0be40528a45 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -22,7 +22,7 @@ import java.util.concurrent.CountDownLatch import scala.collection.mutable -import org.apache.commons.io.{FileUtils, IOUtils} +import org.apache.commons.io.FileUtils import org.apache.commons.lang3.RandomStringUtils import org.apache.hadoop.fs.Path import org.scalactic.TolerantNumerics @@ -30,13 +30,13 @@ import org.scalatest.BeforeAndAfter import org.scalatest.concurrent.PatienceConfiguration.Timeout import org.scalatest.mockito.MockitoSugar -import org.apache.spark.SparkException +import org.apache.spark.{SparkException, TestUtils} import org.apache.spark.internal.Logging import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.catalyst.expressions.{Literal, Rand, Randn, Shuffle, Uuid} import org.apache.spark.sql.execution.exchange.ReusedExchangeExec import org.apache.spark.sql.execution.streaming._ -import org.apache.spark.sql.execution.streaming.sources.TestForeachWriter +import org.apache.spark.sql.execution.streaming.sources.{MemorySink, TestForeachWriter} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.reader.InputPartition @@ -498,7 +498,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi test("input row calculation with same V2 source used twice in self-union") { val streamInput = MemoryStream[Int] - testStream(streamInput.toDF().union(streamInput.toDF()), useV2Sink = true)( + testStream(streamInput.toDF().union(streamInput.toDF()))( AddData(streamInput, 1, 2, 3), CheckAnswer(1, 1, 2, 2, 3, 3), AssertOnQuery { q => @@ -519,7 +519,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi // relation, which breaks exchange reuse, as the optimizer will remove Project from one side. // Here we manually add a useful Project, to trigger exchange reuse. 
val streamDF = memoryStream.toDF().select('value + 0 as "v") - testStream(streamDF.join(streamDF, "v"), useV2Sink = true)( + testStream(streamDF.join(streamDF, "v"))( AddData(memoryStream, 1, 2, 3), CheckAnswer(1, 2, 3), check @@ -556,7 +556,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi val streamInput1 = MemoryStream[Int] val streamInput2 = MemoryStream[Int] - testStream(streamInput1.toDF().union(streamInput2.toDF()), useV2Sink = true)( + testStream(streamInput1.toDF().union(streamInput2.toDF()))( AddData(streamInput1, 1, 2, 3), CheckLastBatch(1, 2, 3), AssertOnQuery { q => @@ -587,7 +587,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi val streamInput = MemoryStream[Int] val staticInputDF = spark.createDataFrame(Seq(1 -> "1", 2 -> "2")).toDF("value", "anotherValue") - testStream(streamInput.toDF().join(staticInputDF, "value"), useV2Sink = true)( + testStream(streamInput.toDF().join(staticInputDF, "value"))( AddData(streamInput, 1, 2, 3), AssertOnQuery { q => q.processAllAvailable() @@ -609,7 +609,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi val streamInput2 = MemoryStream[Int] val staticInputDF2 = staticInputDF.union(staticInputDF).cache() - testStream(streamInput2.toDF().join(staticInputDF2, "value"), useV2Sink = true)( + testStream(streamInput2.toDF().join(staticInputDF2, "value"))( AddData(streamInput2, 1, 2, 3), AssertOnQuery { q => q.processAllAvailable() @@ -717,8 +717,8 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi q3.processAllAvailable() } assert(e.getCause.isInstanceOf[SparkException]) - assert(e.getCause.getCause.isInstanceOf[IllegalStateException]) - assert(e.getMessage.contains("StreamingQuery cannot be used in executors")) + assert(e.getCause.getCause.getCause.isInstanceOf[IllegalStateException]) + TestUtils.assertExceptionMsg(e, "StreamingQuery cannot be used in executors") } finally { q1.stop() q2.stop() @@ -912,7 +912,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingDataSourceV2Relation")) ) - testStream(df, useV2Sink = true)( + testStream(df)( StartStream(trigger = Trigger.Continuous(100)), AssertOnQuery(_.logicalPlan.toJSON.contains("StreamingDataSourceV2Relation")) ) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala index 10bea7f090571..59d6ac0af52a3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousQueryStatusAndProgressSuite.scala @@ -34,7 +34,7 @@ class ContinuousQueryStatusAndProgressSuite extends ContinuousSuiteBase { } val trigger = Trigger.Continuous(100) - testStream(input.toDF(), useV2Sink = true)( + testStream(input.toDF())( StartStream(trigger), Execute(assertStatus), AddData(input, 0, 1, 2), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala index d2e489a7d4ad2..9840c7f066780 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/streaming/continuous/ContinuousSuite.scala @@ -57,7 +57,6 @@ class ContinuousSuiteBase extends StreamTest { protected val longContinuousTrigger = Trigger.Continuous("1 hour") override protected val defaultTrigger = Trigger.Continuous(100) - override protected val defaultUseV2Sink = true } class ContinuousSuite extends ContinuousSuiteBase { @@ -239,7 +238,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { .load() .select('value) - testStream(df, useV2Sink = true)( + testStream(df)( StartStream(longContinuousTrigger), AwaitEpoch(0), Execute(waitForRateSourceTriggers(_, 10)), @@ -257,7 +256,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { .load() .select('value) - testStream(df, useV2Sink = true)( + testStream(df)( StartStream(Trigger.Continuous(2012)), AwaitEpoch(0), Execute(waitForRateSourceTriggers(_, 10)), @@ -274,7 +273,7 @@ class ContinuousStressSuite extends ContinuousSuiteBase { .load() .select('value) - testStream(df, useV2Sink = true)( + testStream(df)( StartStream(Trigger.Continuous(1012)), AwaitEpoch(2), StopStream, @@ -365,7 +364,7 @@ class ContinuousEpochBacklogSuite extends ContinuousSuiteBase { .load() .select('value) - testStream(df, useV2Sink = true)( + testStream(df)( StartStream(Trigger.Continuous(1)), ExpectFailure[IllegalStateException] { e => e.getMessage.contains("queue has exceeded its maximum") From 287e9d72661fa21da2e34564fded2154f484a2eb Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 6 Jun 2019 15:04:36 -0700 Subject: [PATCH 58/70] Fix merge conflicts --- .../spark/sql/streaming/StreamingAggregationSuite.scala | 8 +------- 1 file changed, 1 insertion(+), 7 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index 01b77f30ff922..b2121bc955711 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -32,14 +32,8 @@ import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.execution.streaming._ -<<<<<<< HEAD -import org.apache.spark.sql.execution.streaming.state.{StateStore, StreamingAggregationStateManager} -||||||| parent of fb6b19ab7c... [SPARK-23014][SS] Fully remove V1 memory sink. -import org.apache.spark.sql.execution.streaming.state.StreamingAggregationStateManager -======= import org.apache.spark.sql.execution.streaming.sources.MemorySink -import org.apache.spark.sql.execution.streaming.state.StreamingAggregationStateManager ->>>>>>> fb6b19ab7c... [SPARK-23014][SS] Fully remove V1 memory sink. +import org.apache.spark.sql.execution.streaming.state.{StateStore, StreamingAggregationStateManager} import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf From d97de74b8d02e0e62bf336028ca5f0a3a0cd040b Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Mon, 6 May 2019 20:41:57 +0800 Subject: [PATCH 59/70] [SPARK-27579][SQL] remove BaseStreamingSource and BaseStreamingSink ## What changes were proposed in this pull request? `BaseStreamingSource` and `BaseStreamingSink` is used to unify v1 and v2 streaming data source API in some code paths. 
This PR removes these 2 interfaces, and let the v1 API extend v2 API to keep API compatibility. The motivation is https://github.com/apache/spark/pull/24416 . We want to move data source v2 to catalyst module, but `BaseStreamingSource` and `BaseStreamingSink` are in sql/core. ## How was this patch tested? existing tests Closes #24471 from cloud-fan/streaming. Authored-by: Wenchen Fan Signed-off-by: Wenchen Fan --- .../sql/kafka010/KafkaSourceProvider.scala | 4 +-- .../kafka010/KafkaMicroBatchSourceSuite.scala | 5 ++-- .../v2/reader/streaming/SparkDataStream.java | 8 +++-- .../streaming/BaseStreamingSink.java | 27 ----------------- .../streaming/BaseStreamingSource.java | 29 ------------------- .../datasources/noop/NoopDataSource.scala | 3 +- .../streaming/MicroBatchExecution.scala | 8 ++--- .../sql/execution/streaming/OffsetSeq.scala | 3 +- .../streaming/ProgressReporter.scala | 21 +++++++------- .../spark/sql/execution/streaming/Sink.scala | 21 +++++++++++++- .../sql/execution/streaming/Source.scala | 20 +++++++++++-- .../execution/streaming/StreamExecution.scala | 9 +++--- .../execution/streaming/StreamProgress.scala | 20 +++++++------ .../streaming/StreamingRelation.scala | 3 +- .../sql/execution/streaming/console.scala | 2 +- .../continuous/ContinuousExecution.scala | 2 +- .../sql/execution/streaming/memory.scala | 24 ++++++++------- .../sources/ForeachWriterTable.scala | 3 +- .../execution/streaming/sources/memory.scala | 4 +-- .../sql/streaming/DataStreamWriter.scala | 4 +-- .../sql/streaming/StreamingQueryManager.scala | 6 ++-- .../sources/RateStreamProviderSuite.scala | 4 +-- .../sources/TextSocketStreamSuite.scala | 4 +-- .../spark/sql/streaming/StreamTest.scala | 5 ++-- .../sources/StreamingDataSourceV2Suite.scala | 6 ++-- 25 files changed, 118 insertions(+), 127 deletions(-) delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java delete mode 100644 sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 88363d33525ee..0b661b7eeaf08 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -28,7 +28,7 @@ import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySe import org.apache.spark.internal.Logging import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} -import org.apache.spark.sql.execution.streaming.{BaseStreamingSink, Sink, Source} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.TableCapability._ @@ -353,7 +353,7 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister } class KafkaTable(strategy: => ConsumerStrategy) extends Table - with SupportsRead with SupportsWrite with BaseStreamingSink { + with SupportsRead with SupportsWrite { override def name(): String = s"Kafka $strategy" diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index 21634ae2abfa1..d2503a219a16e 100644 
--- a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -41,6 +41,7 @@ import org.apache.spark.sql.execution.streaming.continuous.ContinuousExecution import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.kafka010.KafkaSourceProvider._ +import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream import org.apache.spark.sql.streaming.{StreamTest, Trigger} import org.apache.spark.sql.streaming.util.StreamManualClock import org.apache.spark.sql.test.SharedSQLContext @@ -94,7 +95,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with Kaf message: String = "", topicAction: (String, Option[Int]) => Unit = (_, _) => {}) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + override def addData(query: Option[StreamExecution]): (SparkDataStream, Offset) = { query match { // Make sure no Spark job is running when deleting a topic case Some(m: MicroBatchExecution) => m.processAllAvailable() @@ -114,7 +115,7 @@ abstract class KafkaSourceTest extends StreamTest with SharedSQLContext with Kaf query.nonEmpty, "Cannot add data when there is no query for finding the active kafka source") - val sources: Seq[BaseStreamingSource] = { + val sources: Seq[SparkDataStream] = { query.get.logicalPlan.collect { case StreamingExecutionRelation(source: KafkaSource, _) => source case r: StreamingDataSourceV2Relation if r.stream.isInstanceOf[KafkaMicroBatchStream] || diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java index 30f38ce37c401..2068a84fc6bb1 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/SparkDataStream.java @@ -18,7 +18,6 @@ package org.apache.spark.sql.sources.v2.reader.streaming; import org.apache.spark.annotation.Evolving; -import org.apache.spark.sql.execution.streaming.BaseStreamingSource; /** * The base interface representing a readable data stream in a Spark streaming query. It's @@ -28,7 +27,7 @@ * {@link MicroBatchStream} and {@link ContinuousStream}. */ @Evolving -public interface SparkDataStream extends BaseStreamingSource { +public interface SparkDataStream { /** * Returns the initial offset for a streaming query to start reading from. Note that the @@ -50,4 +49,9 @@ public interface SparkDataStream extends BaseStreamingSource { * equal to `end` and will only request offsets greater than `end` in the future. */ void commit(Offset end); + + /** + * Stop this source and free any resources it has allocated. + */ + void stop(); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java deleted file mode 100644 index ac96c2765368f..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSink.java +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. 
See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming; - -/** - * The shared interface between V1 and V2 streaming sinks. - * - * This is a temporary interface for compatibility during migration. It should not be implemented - * directly, and will be removed in future versions. - */ -public interface BaseStreamingSink { -} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java deleted file mode 100644 index c44b8af2552f0..0000000000000 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/BaseStreamingSource.java +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.execution.streaming; - -/** - * The shared interface between V1 streaming sources and V2 streaming readers. - * - * This is a temporary interface for compatibility during migration. It should not be implemented - * directly, and will be removed in future versions. - */ -public interface BaseStreamingSource { - /** Stop this source and free any resources it has allocated. 
*/ - void stop(); -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala index 321d006986f1d..2d90fd594fa7d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/noop/NoopDataSource.scala @@ -22,7 +22,6 @@ import java.util import scala.collection.JavaConverters._ import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.streaming.BaseStreamingSink import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.writer._ @@ -39,7 +38,7 @@ class NoopDataSource extends TableProvider with DataSourceRegister { override def getTable(options: CaseInsensitiveStringMap): Table = NoopTable } -private[noop] object NoopTable extends Table with SupportsWrite with BaseStreamingSink { +private[noop] object NoopTable extends Table with SupportsWrite { override def newWriteBuilder(options: CaseInsensitiveStringMap): WriteBuilder = NoopWriteBuilder override def name(): String = "noop-table" override def schema(): StructType = new StructType() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index d9fe836b1c494..58c265d0a8501 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.v2.{StreamingDataSourceV2Relat import org.apache.spark.sql.execution.streaming.sources.{RateControlMicroBatchStream, WriteToMicroBatchDataSource} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2._ -import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, Offset => OffsetV2, SparkDataStream} import org.apache.spark.sql.streaming.{OutputMode, ProcessingTime, Trigger} import org.apache.spark.util.Clock @@ -38,7 +38,7 @@ class MicroBatchExecution( name: String, checkpointRoot: String, analyzedPlan: LogicalPlan, - sink: BaseStreamingSink, + sink: Table, trigger: Trigger, triggerClock: Clock, outputMode: OutputMode, @@ -48,7 +48,7 @@ class MicroBatchExecution( sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { - @volatile protected var sources: Seq[BaseStreamingSource] = Seq.empty + @volatile protected var sources: Seq[SparkDataStream] = Seq.empty private val triggerExecutor = trigger match { case t: ProcessingTime => ProcessingTimeExecutor(t, triggerClock) @@ -354,7 +354,7 @@ class MicroBatchExecution( if (isCurrentBatchConstructed) return true // Generate a map from each unique source to the next available offset. 
- val latestOffsets: Map[BaseStreamingSource, Option[Offset]] = uniqueSources.map { + val latestOffsets: Map[SparkDataStream, Option[Offset]] = uniqueSources.map { case s: Source => updateStatusMessage(s"Getting offsets from $s") reportTimeTaken("getOffset") { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 73cf355dbe758..0f7ad7517e8fe 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -24,6 +24,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager} import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _} +import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream /** * An ordered collection of offsets, used to track the progress of processing data from one or more @@ -39,7 +40,7 @@ case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMet * This method is typically used to associate a serialized offset with actual sources (which * cannot be serialized). */ - def toStreamProgress(sources: Seq[BaseStreamingSource]): StreamProgress = { + def toStreamProgress(sources: Seq[SparkDataStream]): StreamProgress = { assert(sources.size == offsets.size, s"There are [${offsets.size}] sources in the " + s"checkpoint offsets and now there are [${sources.size}] sources requested by the query. " + s"Cannot continue.") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala index 25283515b882f..932daef8965d3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/ProgressReporter.scala @@ -29,7 +29,8 @@ import org.apache.spark.sql.catalyst.plans.logical.{EventTimeWatermark, LogicalP import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.datasources.v2.{MicroBatchScanExec, StreamingDataSourceV2Relation, StreamWriterCommitProgress} -import org.apache.spark.sql.sources.v2.reader.streaming.MicroBatchStream +import org.apache.spark.sql.sources.v2.Table +import org.apache.spark.sql.sources.v2.reader.streaming.{MicroBatchStream, SparkDataStream} import org.apache.spark.sql.streaming._ import org.apache.spark.sql.streaming.StreamingQueryListener.QueryProgressEvent import org.apache.spark.util.Clock @@ -44,7 +45,7 @@ import org.apache.spark.util.Clock trait ProgressReporter extends Logging { case class ExecutionStats( - inputRows: Map[BaseStreamingSource, Long], + inputRows: Map[SparkDataStream, Long], stateOperators: Seq[StateOperatorProgress], eventTimeStats: Map[String, String]) @@ -55,10 +56,10 @@ trait ProgressReporter extends Logging { protected def triggerClock: Clock protected def logicalPlan: LogicalPlan protected def lastExecution: QueryExecution - protected def newData: Map[BaseStreamingSource, LogicalPlan] + protected def newData: Map[SparkDataStream, LogicalPlan] protected def sinkCommitProgress: Option[StreamWriterCommitProgress] - protected def sources: 
Seq[BaseStreamingSource] - protected def sink: BaseStreamingSink + protected def sources: Seq[SparkDataStream] + protected def sink: Table protected def offsetSeqMetadata: OffsetSeqMetadata protected def currentBatchId: Long protected def sparkSession: SparkSession @@ -67,8 +68,8 @@ trait ProgressReporter extends Logging { // Local timestamps and counters. private var currentTriggerStartTimestamp = -1L private var currentTriggerEndTimestamp = -1L - private var currentTriggerStartOffsets: Map[BaseStreamingSource, String] = _ - private var currentTriggerEndOffsets: Map[BaseStreamingSource, String] = _ + private var currentTriggerStartOffsets: Map[SparkDataStream, String] = _ + private var currentTriggerEndOffsets: Map[SparkDataStream, String] = _ // TODO: Restore this from the checkpoint when possible. private var lastTriggerStartTimestamp = -1L @@ -240,9 +241,9 @@ trait ProgressReporter extends Logging { } /** Extract number of input sources for each streaming source in plan */ - private def extractSourceToNumInputRows(): Map[BaseStreamingSource, Long] = { + private def extractSourceToNumInputRows(): Map[SparkDataStream, Long] = { - def sumRows(tuples: Seq[(BaseStreamingSource, Long)]): Map[BaseStreamingSource, Long] = { + def sumRows(tuples: Seq[(SparkDataStream, Long)]): Map[SparkDataStream, Long] = { tuples.groupBy(_._1).mapValues(_.map(_._2).sum) // sum up rows for each source } @@ -262,7 +263,7 @@ trait ProgressReporter extends Logging { val sourceToInputRowsTuples = lastExecution.executedPlan.collect { case s: MicroBatchScanExec => val numRows = s.metrics.get("numOutputRows").map(_.value).getOrElse(0L) - val source = s.stream.asInstanceOf[BaseStreamingSource] + val source = s.stream source -> numRows } logDebug("Source -> # input rows\n\t" + sourceToInputRowsTuples.mkString("\n\t")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala index 34bc085d920c1..190325fb7ec25 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Sink.scala @@ -17,14 +17,21 @@ package org.apache.spark.sql.execution.streaming +import java.util + import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.sources.v2.{Table, TableCapability} +import org.apache.spark.sql.types.StructType /** * An interface for systems that can collect the results of a streaming query. In order to preserve * exactly once semantics a sink must be idempotent in the face of multiple attempts to add the same * batch. + * + * Note that, we extends `Table` here, to make the v1 streaming sink API be compatible with + * data source v2. */ -trait Sink extends BaseStreamingSink { +trait Sink extends Table { /** * Adds a batch of data to this sink. The data for a given `batchId` is deterministic and if @@ -38,4 +45,16 @@ trait Sink extends BaseStreamingSink { * after data is consumed by sink successfully. 
*/ def addBatch(batchId: Long, data: DataFrame): Unit + + override def name: String = { + throw new IllegalStateException("should not be called.") + } + + override def schema: StructType = { + throw new IllegalStateException("should not be called.") + } + + override def capabilities: util.Set[TableCapability] = { + throw new IllegalStateException("should not be called.") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala index dbbd59e06909c..7f66d0b055cc3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Source.scala @@ -18,14 +18,19 @@ package org.apache.spark.sql.execution.streaming import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream import org.apache.spark.sql.types.StructType /** * A source of continually arriving data for a streaming query. A [[Source]] must have a * monotonically increasing notion of progress that can be represented as an [[Offset]]. Spark * will regularly query each [[Source]] to see if any more data is available. + * + * Note that, we extends `SparkDataStream` here, to make the v1 streaming source API be compatible + * with data source v2. */ -trait Source extends BaseStreamingSource { +trait Source extends SparkDataStream { /** Returns the schema of the data from this source */ def schema: StructType @@ -62,6 +67,15 @@ trait Source extends BaseStreamingSource { */ def commit(end: Offset) : Unit = {} - /** Stop this source and free any resources it has allocated. 
*/ - def stop(): Unit + override def initialOffset(): OffsetV2 = { + throw new IllegalStateException("should not be called.") + } + + override def deserializeOffset(json: String): OffsetV2 = { + throw new IllegalStateException("should not be called.") + } + + override def commit(end: OffsetV2): Unit = { + throw new IllegalStateException("should not be called.") + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index fd959619650e4..5d66b61ae7111 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -40,7 +40,8 @@ import org.apache.spark.sql.execution.QueryExecution import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.execution.datasources.v2.StreamWriterCommitProgress import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.SupportsWrite +import org.apache.spark.sql.sources.v2.{SupportsWrite, Table} +import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream import org.apache.spark.sql.sources.v2.writer.SupportsTruncate import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.streaming._ @@ -69,7 +70,7 @@ abstract class StreamExecution( override val name: String, private val checkpointRoot: String, analyzedPlan: LogicalPlan, - val sink: BaseStreamingSink, + val sink: Table, val trigger: Trigger, val triggerClock: Clock, val outputMode: OutputMode, @@ -205,7 +206,7 @@ abstract class StreamExecution( /** * A list of unique sources in the query plan. This will be set when generating logical plan. */ - @volatile protected var uniqueSources: Seq[BaseStreamingSource] = Seq.empty + @volatile protected var uniqueSources: Seq[SparkDataStream] = Seq.empty /** Defines the internal state of execution */ protected val state = new AtomicReference[State](INITIALIZING) @@ -214,7 +215,7 @@ abstract class StreamExecution( var lastExecution: IncrementalExecution = _ /** Holds the most recent input data for each source. */ - protected var newData: Map[BaseStreamingSource, LogicalPlan] = _ + protected var newData: Map[SparkDataStream, LogicalPlan] = _ @volatile protected var streamDeathCause: StreamingQueryException = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala index 8531070b1bc49..8a1d064f49d1c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -19,32 +19,34 @@ package org.apache.spark.sql.execution.streaming import scala.collection.{immutable, GenTraversableOnce} +import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream + /** * A helper class that looks like a Map[Source, Offset]. 
*/ class StreamProgress( - val baseMap: immutable.Map[BaseStreamingSource, Offset] = - new immutable.HashMap[BaseStreamingSource, Offset]) - extends scala.collection.immutable.Map[BaseStreamingSource, Offset] { + val baseMap: immutable.Map[SparkDataStream, Offset] = + new immutable.HashMap[SparkDataStream, Offset]) + extends scala.collection.immutable.Map[SparkDataStream, Offset] { - def toOffsetSeq(source: Seq[BaseStreamingSource], metadata: OffsetSeqMetadata): OffsetSeq = { + def toOffsetSeq(source: Seq[SparkDataStream], metadata: OffsetSeqMetadata): OffsetSeq = { OffsetSeq(source.map(get), Some(metadata)) } override def toString: String = baseMap.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}") - override def +[B1 >: Offset](kv: (BaseStreamingSource, B1)): Map[BaseStreamingSource, B1] = { + override def +[B1 >: Offset](kv: (SparkDataStream, B1)): Map[SparkDataStream, B1] = { baseMap + kv } - override def get(key: BaseStreamingSource): Option[Offset] = baseMap.get(key) + override def get(key: SparkDataStream): Option[Offset] = baseMap.get(key) - override def iterator: Iterator[(BaseStreamingSource, Offset)] = baseMap.iterator + override def iterator: Iterator[(SparkDataStream, Offset)] = baseMap.iterator - override def -(key: BaseStreamingSource): Map[BaseStreamingSource, Offset] = baseMap - key + override def -(key: SparkDataStream): Map[SparkDataStream, Offset] = baseMap - key - def ++(updates: GenTraversableOnce[(BaseStreamingSource, Offset)]): StreamProgress = { + def ++(updates: GenTraversableOnce[(SparkDataStream, Offset)]): StreamProgress = { new StreamProgress(baseMap ++ updates) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala index 0d7e9ba363d01..142b6e7d18068 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamingRelation.scala @@ -26,6 +26,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan, Stati import org.apache.spark.sql.execution.LeafExecNode import org.apache.spark.sql.execution.datasources.DataSource import org.apache.spark.sql.sources.v2.{Table, TableProvider} +import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream import org.apache.spark.sql.util.CaseInsensitiveStringMap object StreamingRelation { @@ -63,7 +64,7 @@ case class StreamingRelation(dataSource: DataSource, sourceName: String, output: * [[org.apache.spark.sql.catalyst.plans.logical.LogicalPlan]]. 
*/ case class StreamingExecutionRelation( - source: BaseStreamingSource, + source: SparkDataStream, output: Seq[Attribute])(session: SparkSession) extends LeafNode with MultiInstanceRelation { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala index c7161d311c028..9ae39c79c5156 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/console.scala @@ -61,7 +61,7 @@ class ConsoleSinkProvider extends TableProvider def shortName(): String = "console" } -object ConsoleTable extends Table with SupportsWrite with BaseStreamingSink { +object ConsoleTable extends Table with SupportsWrite { override def name(): String = "console" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala index ef0c942e959ea..5475becc5bff4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/continuous/ContinuousExecution.scala @@ -49,7 +49,7 @@ class ContinuousExecution( extraOptions: Map[String, String], deleteCheckpointOnStop: Boolean) extends StreamExecution( - sparkSession, name, checkpointRoot, analyzedPlan, sink.asInstanceOf[BaseStreamingSink], + sparkSession, name, checkpointRoot, analyzedPlan, sink, trigger, triggerClock, outputMode, deleteCheckpointOnStop) { @volatile protected var sources: Seq[ContinuousStream] = Seq() diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 6efde0a27efe9..022c8da0c074e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.util.truncatedString import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2._ import org.apache.spark.sql.sources.v2.reader._ -import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream, Offset => OffsetV2} +import org.apache.spark.sql.sources.v2.reader.streaming.{ContinuousStream, MicroBatchStream, Offset => OffsetV2, SparkDataStream} import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap @@ -49,7 +49,7 @@ object MemoryStream { /** * A base class for memory stream implementations. Supports adding data and resetting. 
*/ -abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends BaseStreamingSource { +abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends SparkDataStream { val encoder = encoderFor[A] protected val attributes = encoder.schema.toAttributes @@ -78,6 +78,18 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Bas } def addData(data: TraversableOnce[A]): Offset + + override def initialOffset(): OffsetV2 = { + throw new IllegalStateException("should not be called.") + } + + override def deserializeOffset(json: String): OffsetV2 = { + throw new IllegalStateException("should not be called.") + } + + override def commit(end: OffsetV2): Unit = { + throw new IllegalStateException("should not be called.") + } } // This class is used to indicate the memory stream data source. We don't actually use it, as @@ -264,11 +276,3 @@ object MemoryStreamReaderFactory extends PartitionReaderFactory { } } } - -/** A common trait for MemorySinks with methods used for testing */ -trait MemorySinkBase extends BaseStreamingSink { - def allData: Seq[Row] - def latestBatchData: Seq[Row] - def dataSinceBatch(sinceBatchId: Long): Seq[Row] - def latestBatchId: Option[Long] -} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala index 838ede6c563f2..6da1b3a49c442 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ForeachWriterTable.scala @@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.execution.python.PythonForeachWriter -import org.apache.spark.sql.execution.streaming.BaseStreamingSink import org.apache.spark.sql.sources.v2.{SupportsWrite, Table, TableCapability} import org.apache.spark.sql.sources.v2.writer.{DataWriter, SupportsTruncate, WriteBuilder, WriterCommitMessage} import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} @@ -44,7 +43,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap case class ForeachWriterTable[T]( writer: ForeachWriter[T], converter: Either[ExpressionEncoder[T], InternalRow => T]) - extends Table with SupportsWrite with BaseStreamingSink { + extends Table with SupportsWrite { override def name(): String = "ForeachSink" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala index 9008c63491cb6..de8d00d4ac348 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/memory.scala @@ -32,7 +32,7 @@ import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils -import org.apache.spark.sql.execution.streaming.{MemorySinkBase, Sink} +import org.apache.spark.sql.execution.streaming.Sink import org.apache.spark.sql.sources.v2.{SupportsWrite, Table, TableCapability} 
import org.apache.spark.sql.sources.v2.writer._ import org.apache.spark.sql.sources.v2.writer.streaming.{StreamingDataWriterFactory, StreamingWrite} @@ -43,7 +43,7 @@ import org.apache.spark.sql.util.CaseInsensitiveStringMap * A sink that stores the results in memory. This [[Sink]] is primarily intended for use in unit * tests and does not provide durability. */ -class MemorySink extends Table with SupportsWrite with MemorySinkBase with Logging { +class MemorySink extends Table with SupportsWrite with Logging { override def name(): String = "MemorySink" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala index 2f12efe04c507..d051cf9c1d4a1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/DataStreamWriter.scala @@ -311,7 +311,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Implicits._ provider.getTable(dsOptions) match { case table: SupportsWrite if table.supports(STREAMING_WRITE) => - table.asInstanceOf[BaseStreamingSink] + table case _ => createV1Sink() } } else { @@ -331,7 +331,7 @@ final class DataStreamWriter[T] private[sql](ds: Dataset[T]) { } } - private def createV1Sink(): BaseStreamingSink = { + private def createV1Sink(): Sink = { val ds = DataSource( df.sparkSession, className = source, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala index 040f1723fb93b..63fb9ed176b9f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/streaming/StreamingQueryManager.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, import org.apache.spark.sql.execution.streaming.state.StateStoreCoordinatorRef import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.StaticSQLConf.STREAMING_QUERY_LISTENERS -import org.apache.spark.sql.sources.v2.SupportsWrite +import org.apache.spark.sql.sources.v2.{SupportsWrite, Table} import org.apache.spark.util.{Clock, SystemClock, Utils} /** @@ -206,7 +206,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo userSpecifiedCheckpointLocation: Option[String], df: DataFrame, extraOptions: Map[String, String], - sink: BaseStreamingSink, + sink: Table, outputMode: OutputMode, useTempCheckpointLocation: Boolean, recoverFromCheckpointLocation: Boolean, @@ -312,7 +312,7 @@ class StreamingQueryManager private[sql] (sparkSession: SparkSession) extends Lo userSpecifiedCheckpointLocation: Option[String], df: DataFrame, extraOptions: Map[String, String], - sink: BaseStreamingSink, + sink: Table, outputMode: OutputMode, useTempCheckpointLocation: Boolean = false, recoverFromCheckpointLocation: Boolean = true, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala index b024f957020a7..ef88598fcb11b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/RateStreamProviderSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relati import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.sources.v2.reader.streaming.Offset +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.util.CaseInsensitiveStringMap import org.apache.spark.util.ManualClock @@ -39,7 +39,7 @@ class RateStreamProviderSuite extends StreamTest { import testImplicits._ case class AdvanceRateManualClock(seconds: Long) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + override def addData(query: Option[StreamExecution]): (SparkDataStream, Offset) = { assert(query.nonEmpty) val rateSource = query.get.logicalPlan.collect { case r: StreamingDataSourceV2Relation diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala index 956339355de48..3c451e0538721 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketStreamSuite.scala @@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.datasources.v2.StreamingDataSourceV2Relati import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous._ import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.sources.v2.reader.streaming.Offset +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset, SparkDataStream} import org.apache.spark.sql.streaming.{StreamingQueryException, StreamTest} import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.types._ @@ -55,7 +55,7 @@ class TextSocketStreamSuite extends StreamTest with SharedSQLContext with Before private var serverThread: ServerThread = null case class AddSocketData(data: String*) extends AddData { - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + override def addData(query: Option[StreamExecution]): (SparkDataStream, Offset) = { require( query.nonEmpty, "Cannot add data when there is no query for finding the active socket source") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 69889638e7617..210d7300d95ab 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -42,6 +42,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySink import org.apache.spark.sql.execution.streaming.state.StateStore +import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -123,7 
+124,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be * the active query, and then return the source object the data was added, as well as the * offset of added data. */ - def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) + def addData(query: Option[StreamExecution]): (SparkDataStream, Offset) } /** A trait that can be extended when testing a source. */ @@ -134,7 +135,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class AddDataMemory[A](source: MemoryStreamBase[A], data: Seq[A]) extends AddData { override def toString: String = s"AddData to $source: ${data.mkString(",")}" - override def addData(query: Option[StreamExecution]): (BaseStreamingSource, Offset) = { + override def addData(query: Option[StreamExecution]): (SparkDataStream, Offset) = { (source, source.addData(data)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala index 25a68e4f9a57c..7b2c1a56e8baa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/sources/StreamingDataSourceV2Suite.scala @@ -22,9 +22,9 @@ import java.util.Collections import scala.collection.JavaConverters._ -import org.apache.spark.sql.{AnalysisException, DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.execution.datasources.DataSource -import org.apache.spark.sql.execution.streaming.{BaseStreamingSink, RateStreamOffset, Sink, StreamingQueryWrapper} +import org.apache.spark.sql.execution.streaming.{RateStreamOffset, Sink, StreamingQueryWrapper} import org.apache.spark.sql.execution.streaming.continuous.ContinuousTrigger import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider} @@ -80,7 +80,7 @@ class FakeWriteBuilder extends WriteBuilder with StreamingWrite { } } -trait FakeStreamingWriteTable extends Table with SupportsWrite with BaseStreamingSink { +trait FakeStreamingWriteTable extends Table with SupportsWrite { override def name(): String = "fake" override def schema(): StructType = StructType(Seq()) override def capabilities(): util.Set[TableCapability] = { From 12746c113f050a67d56a8e903b01740062aa06c5 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 7 May 2019 23:03:15 -0700 Subject: [PATCH 60/70] [SPARK-27642][SS] make v1 offset extends v2 offset ## What changes were proposed in this pull request? To move DS v2 to the catalyst module, we can't make v2 offset rely on v1 offset, as v1 offset is in sql/core. ## How was this patch tested? existing tests Closes #24538 from cloud-fan/offset. 
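As an illustration only (not part of the patch), a minimal spark-shell-style sketch of what the flipped hierarchy allows: the v1 `Offset` in sql/core is now an empty subclass of the v2 `Offset` in catalyst, so v1 offsets flow through v2-typed APIs and the removed `LongOffset.convert` helper becomes a plain cast.

```scala
import org.apache.spark.sql.execution.streaming.LongOffset
import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2}

// A v1 offset is now itself a v2 offset, so it can be stored wherever v2 offsets are expected.
val committed: OffsetV2 = LongOffset(42L)

// Mirrors what MemoryStream.commit and TextSocketMicroBatchStream.commit do after this change:
// a plain downcast instead of the removed LongOffset.convert(...).
val asLong: LongOffset = committed.asInstanceOf[LongOffset]
```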
Authored-by: Wenchen Fan Signed-off-by: gatorsmile --- .../sql/kafka010/KafkaContinuousStream.scala | 2 +- .../sql/kafka010/KafkaSourceOffset.scala | 4 +- .../sources/v2/reader/streaming/Offset.java | 11 ++--- .../spark/sql/execution/streaming/Offset.java | 42 ++----------------- .../sql/execution/streaming/LongOffset.scala | 14 +------ .../streaming/MicroBatchExecution.scala | 10 ++--- .../sql/execution/streaming/OffsetSeq.scala | 9 ++-- .../execution/streaming/OffsetSeqLog.scala | 3 +- .../execution/streaming/StreamExecution.scala | 4 +- .../execution/streaming/StreamProgress.scala | 19 +++++---- .../sql/execution/streaming/memory.scala | 25 ++++------- .../sources/TextSocketMicroBatchStream.scala | 5 +-- .../spark/sql/streaming/StreamTest.scala | 8 ++-- 13 files changed, 49 insertions(+), 107 deletions(-) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala index d60ee1cadd195..92686d24e2b8a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaContinuousStream.scala @@ -76,7 +76,7 @@ class KafkaContinuousStream( } override def planInputPartitions(start: Offset): Array[InputPartition] = { - val oldStartPartitionOffsets = KafkaSourceOffset.getPartitionOffsets(start) + val oldStartPartitionOffsets = start.asInstanceOf[KafkaSourceOffset].partitionToOffsets val currentPartitionSet = offsetReader.fetchEarliestOffsets().keySet val newPartitions = currentPartitionSet.diff(oldStartPartitionOffsets.keySet) diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala index 8d41c0da2b133..90d70439c5329 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceOffset.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.kafka010 import org.apache.kafka.common.TopicPartition import org.apache.spark.sql.execution.streaming.{Offset, SerializedOffset} -import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, PartitionOffset} +import org.apache.spark.sql.sources.v2.reader.streaming.PartitionOffset /** * An [[Offset]] for the [[KafkaSource]]. This one tracks all partitions of subscribed topics and * their offsets. */ private[kafka010] -case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends OffsetV2 { +case class KafkaSourceOffset(partitionToOffsets: Map[TopicPartition, Long]) extends Offset { override val json = JsonUtils.partitionOffsets(partitionToOffsets) } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java index a06671383ac5f..1d34fdd1c28ab 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/sources/v2/reader/streaming/Offset.java @@ -25,13 +25,9 @@ * During execution, offsets provided by the data source implementation will be logged and used as * restart checkpoints. 
Each source should provide an offset implementation which the source can use * to reconstruct a position in the stream up to which data has been seen/processed. - * - * Note: This class currently extends {@link org.apache.spark.sql.execution.streaming.Offset} to - * maintain compatibility with DataSource V1 APIs. This extension will be removed once we - * get rid of V1 completely. */ @Evolving -public abstract class Offset extends org.apache.spark.sql.execution.streaming.Offset { +public abstract class Offset { /** * A JSON-serialized representation of an Offset that is * used for saving offsets to the offset log. @@ -49,9 +45,8 @@ public abstract class Offset extends org.apache.spark.sql.execution.streaming.Of */ @Override public boolean equals(Object obj) { - if (obj instanceof org.apache.spark.sql.execution.streaming.Offset) { - return this.json() - .equals(((org.apache.spark.sql.execution.streaming.Offset) obj).json()); + if (obj instanceof Offset) { + return this.json().equals(((Offset) obj).json()); } else { return false; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/Offset.java b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/Offset.java index 43ad4b3384ec3..7c167dc012329 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/Offset.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/streaming/Offset.java @@ -18,44 +18,10 @@ package org.apache.spark.sql.execution.streaming; /** - * This is an internal, deprecated interface. New source implementations should use the - * org.apache.spark.sql.sources.v2.reader.streaming.Offset class, which is the one that will be - * supported in the long term. + * This class is an alias of {@link org.apache.spark.sql.sources.v2.reader.streaming.Offset}. It's + * internal and deprecated. New streaming data source implementations should use data source v2 API, + * which will be supported in the long term. * * This class will be removed in a future release. */ -public abstract class Offset { - /** - * A JSON-serialized representation of an Offset that is - * used for saving offsets to the offset log. - * Note: We assume that equivalent/equal offsets serialize to - * identical JSON strings. - * - * @return JSON string encoding - */ - public abstract String json(); - - /** - * Equality based on JSON string representation. We leverage the - * JSON representation for normalization between the Offset's - * in memory and on disk representations. - */ - @Override - public boolean equals(Object obj) { - if (obj instanceof Offset) { - return this.json().equals(((Offset) obj).json()); - } else { - return false; - } - } - - @Override - public int hashCode() { - return this.json().hashCode(); - } - - @Override - public String toString() { - return this.json(); - } -} +public abstract class Offset extends org.apache.spark.sql.sources.v2.reader.streaming.Offset {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala index 3ff5b86ac45d6..a27898cb0c9fc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/LongOffset.scala @@ -17,12 +17,10 @@ package org.apache.spark.sql.execution.streaming -import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} - /** * A simple offset for sources that produce a single linear stream of data. 
*/ -case class LongOffset(offset: Long) extends OffsetV2 { +case class LongOffset(offset: Long) extends Offset { override val json = offset.toString @@ -37,14 +35,4 @@ object LongOffset { * @return new LongOffset */ def apply(offset: SerializedOffset) : LongOffset = new LongOffset(offset.json.toLong) - - /** - * Convert generic Offset to LongOffset if possible. - * @return converted LongOffset - */ - def convert(offset: Offset): Option[LongOffset] = offset match { - case lo: LongOffset => Some(lo) - case so: SerializedOffset => Some(LongOffset(so)) - case _ => None - } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala index 58c265d0a8501..7a3cdbc926446 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecution.scala @@ -296,7 +296,7 @@ class MicroBatchExecution( * batch will be executed before getOffset is called again. */ availableOffsets.foreach { case (source: Source, end: Offset) => - val start = committedOffsets.get(source) + val start = committedOffsets.get(source).map(_.asInstanceOf[Offset]) source.getBatch(start, end) case nonV1Tuple => // The V2 API does not have the same edge case requiring getBatch to be called @@ -354,7 +354,7 @@ class MicroBatchExecution( if (isCurrentBatchConstructed) return true // Generate a map from each unique source to the next available offset. - val latestOffsets: Map[SparkDataStream, Option[Offset]] = uniqueSources.map { + val latestOffsets: Map[SparkDataStream, Option[OffsetV2]] = uniqueSources.map { case s: Source => updateStatusMessage(s"Getting offsets from $s") reportTimeTaken("getOffset") { @@ -411,7 +411,7 @@ class MicroBatchExecution( val prevBatchOff = offsetLog.get(currentBatchId - 1) if (prevBatchOff.isDefined) { prevBatchOff.get.toStreamProgress(sources).foreach { - case (src: Source, off) => src.commit(off) + case (src: Source, off: Offset) => src.commit(off) case (stream: MicroBatchStream, off) => stream.commit(stream.deserializeOffset(off.json)) case (src, _) => @@ -448,9 +448,9 @@ class MicroBatchExecution( // Request unprocessed data from all sources. 
newData = reportTimeTaken("getBatch") { availableOffsets.flatMap { - case (source: Source, available) + case (source: Source, available: Offset) if committedOffsets.get(source).map(_ != available).getOrElse(true) => - val current = committedOffsets.get(source) + val current = committedOffsets.get(source).map(_.asInstanceOf[Offset]) val batch = source.getBatch(current, available) assert(batch.isStreaming, s"DataFrame returned by getBatch from $source did not have isStreaming=true\n" + diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala index 0f7ad7517e8fe..b6fa2e9dc3612 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeq.scala @@ -24,14 +24,15 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.RuntimeConfig import org.apache.spark.sql.execution.streaming.state.{FlatMapGroupsWithStateExecHelper, StreamingAggregationStateManager} import org.apache.spark.sql.internal.SQLConf.{FLATMAPGROUPSWITHSTATE_STATE_FORMAT_VERSION, _} -import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, SparkDataStream} + /** * An ordered collection of offsets, used to track the progress of processing data from one or more * [[Source]]s that are present in a streaming query. This is similar to simplified, single-instance * vector clock that must progress linearly forward. */ -case class OffsetSeq(offsets: Seq[Option[Offset]], metadata: Option[OffsetSeqMetadata] = None) { +case class OffsetSeq(offsets: Seq[Option[OffsetV2]], metadata: Option[OffsetSeqMetadata] = None) { /** * Unpacks an offset into [[StreamProgress]] by associating each offset with the ordered list of @@ -57,13 +58,13 @@ object OffsetSeq { * Returns a [[OffsetSeq]] with a variable sequence of offsets. * `nulls` in the sequence are converted to `None`s. */ - def fill(offsets: Offset*): OffsetSeq = OffsetSeq.fill(None, offsets: _*) + def fill(offsets: OffsetV2*): OffsetSeq = OffsetSeq.fill(None, offsets: _*) /** * Returns a [[OffsetSeq]] with metadata and a variable sequence of offsets. * `nulls` in the sequence are converted to `None`s. */ - def fill(metadata: Option[String], offsets: Offset*): OffsetSeq = { + def fill(metadata: Option[String], offsets: OffsetV2*): OffsetSeq = { OffsetSeq(offsets.map(Option(_)), metadata.map(OffsetSeqMetadata.apply)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala index 2c8d7c7b0f3c5..8a05dade092c8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/OffsetSeqLog.scala @@ -24,6 +24,7 @@ import java.nio.charset.StandardCharsets._ import scala.io.{Source => IOSource} import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2} /** * This class is used to log offsets to persistent files in HDFS. 
@@ -47,7 +48,7 @@ class OffsetSeqLog(sparkSession: SparkSession, path: String) override protected def deserialize(in: InputStream): OffsetSeq = { // called inside a try-finally where the underlying stream is closed in the caller - def parseOffset(value: String): Offset = value match { + def parseOffset(value: String): OffsetV2 = value match { case OffsetSeqLog.SERIALIZED_VOID_OFFSET => null case json => SerializedOffset(json) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala index 5d66b61ae7111..4c08b3aa78666 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamExecution.scala @@ -41,7 +41,7 @@ import org.apache.spark.sql.execution.command.StreamingExplainCommand import org.apache.spark.sql.execution.datasources.v2.StreamWriterCommitProgress import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.sources.v2.{SupportsWrite, Table} -import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, SparkDataStream} import org.apache.spark.sql.sources.v2.writer.SupportsTruncate import org.apache.spark.sql.sources.v2.writer.streaming.StreamingWrite import org.apache.spark.sql.streaming._ @@ -438,7 +438,7 @@ abstract class StreamExecution( * Blocks the current thread until processing for data from the given `source` has reached at * least the given `Offset`. This method is intended for use primarily when writing tests. */ - private[sql] def awaitOffset(sourceIndex: Int, newOffset: Offset, timeoutMs: Long): Unit = { + private[sql] def awaitOffset(sourceIndex: Int, newOffset: OffsetV2, timeoutMs: Long): Unit = { assertAwaitThread() def notDone = { val localCommittedOffsets = committedOffsets diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala index 8a1d064f49d1c..8783eaa0e68b3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/StreamProgress.scala @@ -19,15 +19,16 @@ package org.apache.spark.sql.execution.streaming import scala.collection.{immutable, GenTraversableOnce} -import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, SparkDataStream} + /** * A helper class that looks like a Map[Source, Offset]. 
*/ class StreamProgress( - val baseMap: immutable.Map[SparkDataStream, Offset] = - new immutable.HashMap[SparkDataStream, Offset]) - extends scala.collection.immutable.Map[SparkDataStream, Offset] { + val baseMap: immutable.Map[SparkDataStream, OffsetV2] = + new immutable.HashMap[SparkDataStream, OffsetV2]) + extends scala.collection.immutable.Map[SparkDataStream, OffsetV2] { def toOffsetSeq(source: Seq[SparkDataStream], metadata: OffsetSeqMetadata): OffsetSeq = { OffsetSeq(source.map(get), Some(metadata)) @@ -36,17 +37,17 @@ class StreamProgress( override def toString: String = baseMap.map { case (k, v) => s"$k: $v"}.mkString("{", ",", "}") - override def +[B1 >: Offset](kv: (SparkDataStream, B1)): Map[SparkDataStream, B1] = { + override def +[B1 >: OffsetV2](kv: (SparkDataStream, B1)): Map[SparkDataStream, B1] = { baseMap + kv } - override def get(key: SparkDataStream): Option[Offset] = baseMap.get(key) + override def get(key: SparkDataStream): Option[OffsetV2] = baseMap.get(key) - override def iterator: Iterator[(SparkDataStream, Offset)] = baseMap.iterator + override def iterator: Iterator[(SparkDataStream, OffsetV2)] = baseMap.iterator - override def -(key: SparkDataStream): Map[SparkDataStream, Offset] = baseMap - key + override def -(key: SparkDataStream): Map[SparkDataStream, OffsetV2] = baseMap - key - def ++(updates: GenTraversableOnce[(SparkDataStream, Offset)]): StreamProgress = { + def ++(updates: GenTraversableOnce[(SparkDataStream, OffsetV2)]): StreamProgress = { new StreamProgress(baseMap ++ updates) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala index 022c8da0c074e..df149552dfb30 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/memory.scala @@ -61,10 +61,12 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Spa Dataset.ofRows(sqlContext.sparkSession, logicalPlan) } - def addData(data: A*): Offset = { + def addData(data: A*): OffsetV2 = { addData(data.toTraversable) } + def addData(data: TraversableOnce[A]): OffsetV2 + def fullSchema(): StructType = encoder.schema protected val logicalPlan: LogicalPlan = { @@ -77,8 +79,6 @@ abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends Spa None)(sqlContext.sparkSession) } - def addData(data: TraversableOnce[A]): Offset - override def initialOffset(): OffsetV2 = { throw new IllegalStateException("should not be called.") } @@ -226,22 +226,15 @@ case class MemoryStream[A : Encoder](id: Int, sqlContext: SQLContext) } override def commit(end: OffsetV2): Unit = synchronized { - def check(newOffset: LongOffset): Unit = { - val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt + val newOffset = end.asInstanceOf[LongOffset] + val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt - if (offsetDiff < 0) { - sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end") - } - - batches.trimStart(offsetDiff) - lastOffsetCommitted = newOffset + if (offsetDiff < 0) { + sys.error(s"Offsets committed out of order: $lastOffsetCommitted followed by $end") } - LongOffset.convert(end) match { - case Some(lo) => check(lo) - case None => sys.error(s"MemoryStream.commit() received an offset ($end) " + - "that did not originate with an instance of this class") - } + batches.trimStart(offsetDiff) + lastOffsetCommitted 
= newOffset } override def stop() {} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala index 9168d46493aef..dd8d89238008e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/TextSocketMicroBatchStream.scala @@ -153,10 +153,7 @@ class TextSocketMicroBatchStream(host: String, port: Int, numPartitions: Int) } override def commit(end: Offset): Unit = synchronized { - val newOffset = LongOffset.convert(end).getOrElse( - sys.error(s"TextSocketStream.commit() received an offset ($end) that did not " + - s"originate with an instance of this class") - ) + val newOffset = end.asInstanceOf[LongOffset] val offsetDiff = (newOffset.offset - lastOffsetCommitted.offset).toInt diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 210d7300d95ab..11acb534e9d69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -42,7 +42,7 @@ import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.continuous.{ContinuousExecution, EpochCoordinatorRef, IncrementAndGetEpoch} import org.apache.spark.sql.execution.streaming.sources.MemorySink import org.apache.spark.sql.execution.streaming.state.StateStore -import org.apache.spark.sql.sources.v2.reader.streaming.SparkDataStream +import org.apache.spark.sql.sources.v2.reader.streaming.{Offset => OffsetV2, SparkDataStream} import org.apache.spark.sql.streaming.StreamingQueryListener._ import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.{Clock, SystemClock, Utils} @@ -124,7 +124,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be * the active query, and then return the source object the data was added, as well as the * offset of added data. */ - def addData(query: Option[StreamExecution]): (SparkDataStream, Offset) + def addData(query: Option[StreamExecution]): (SparkDataStream, OffsetV2) } /** A trait that can be extended when testing a source. 
*/ @@ -135,7 +135,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be case class AddDataMemory[A](source: MemoryStreamBase[A], data: Seq[A]) extends AddData { override def toString: String = s"AddData to $source: ${data.mkString(",")}" - override def addData(query: Option[StreamExecution]): (SparkDataStream, Offset) = { + override def addData(query: Option[StreamExecution]): (SparkDataStream, OffsetV2) = { (source, source.addData(data)) } } @@ -337,7 +337,7 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be var pos = 0 var currentStream: StreamExecution = null var lastStream: StreamExecution = null - val awaiting = new mutable.HashMap[Int, Offset]() // source index -> offset to wait for + val awaiting = new mutable.HashMap[Int, OffsetV2]() // source index -> offset to wait for val sink = new MemorySink val resetConfValues = mutable.Map[String, Option[String]]() val defaultCheckpointLocation = From c99b896e5b96e72351f55bd325d29c2e9a265e0a Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 6 Jun 2019 15:09:46 -0700 Subject: [PATCH 61/70] Fix imports --- .../spark/sql/streaming/StreamingAggregationSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala index b2121bc955711..3f304e9ec7788 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingAggregationSuite.scala @@ -21,7 +21,7 @@ import java.io.File import java.util.{Locale, TimeZone} import org.apache.commons.io.FileUtils -import org.scalatest.{Assertions, BeforeAndAfterAll} +import org.scalatest.Assertions import org.apache.spark.{SparkEnv, SparkException} import org.apache.spark.rdd.BlockRDD @@ -33,7 +33,7 @@ import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} import org.apache.spark.sql.execution.exchange.Exchange import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.execution.streaming.sources.MemorySink -import org.apache.spark.sql.execution.streaming.state.{StateStore, StreamingAggregationStateManager} +import org.apache.spark.sql.execution.streaming.state.StreamingAggregationStateManager import org.apache.spark.sql.expressions.scalalang.typed import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf From 5d2096ed8ce2f59ad699c8817131988d37dd6695 Mon Sep 17 00:00:00 2001 From: Ryan Blue Date: Sun, 19 May 2019 21:30:20 -0700 Subject: [PATCH 62/70] [SPARK-27693][SQL] Add default catalog property Add a SQL config property for the default v2 catalog. Existing tests for regressions. Closes #24594 from rdblue/SPARK-27693-add-default-catalog-config. 
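As an illustration only (not part of the patch), a hedged spark-shell-style sketch of how the new property is meant to be used. `testcat`, `SomeTestCatalog` and the provider name are placeholders, and the `spark.sql.catalog.<name>` registration key is assumed from the existing CatalogPlugin lookup rather than added here.

```scala
// Register a hypothetical v2 catalog implementation under the name `testcat`
// (assumed existing CatalogPlugin registration key, not introduced by this patch).
spark.conf.set("spark.sql.catalog.testcat", classOf[SomeTestCatalog].getName)

// The property added by this patch: unqualified v2 table names now resolve through `testcat`.
spark.conf.set("spark.sql.default.catalog", "testcat")

// A v2 CTAS whose table name carries no explicit catalog (e.g. `mydb.page_view`)
// is resolved via the default catalog instead of failing with
// "No catalog specified for table ... and no default catalog is set".
spark.sql("CREATE TABLE IF NOT EXISTS mydb.page_view USING fooProvider AS SELECT * FROM source")
```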
Authored-by: Ryan Blue Signed-off-by: Dongjoon Hyun --- .../apache/spark/sql/internal/SQLConf.scala | 77 +++++++++++++++++++ .../datasources/DataSourceResolution.scala | 4 +- .../command/PlanResolutionSuite.scala | 3 +- .../sql/sources/v2/DataSourceV2SQLSuite.scala | 3 +- 4 files changed, 82 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 759c22189d874..c7dc505f4bd3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1700,6 +1700,73 @@ object SQLConf { "a SparkConf entry.") .booleanConf .createWithDefault(true) +<<<<<<< HEAD +||||||| parent of bc46feaced... [SPARK-27693][SQL] Add default catalog property + + val DATETIME_JAVA8API_ENABLED = buildConf("spark.sql.datetime.java8API.enabled") + .doc("If the configuration property is set to true, java.time.Instant and " + + "java.time.LocalDate classes of Java 8 API are used as external types for " + + "Catalyst's TimestampType and DateType. If it is set to false, java.sql.Timestamp " + + "and java.sql.Date are used for the same purpose.") + .booleanConf + .createWithDefault(false) + + val UTC_TIMESTAMP_FUNC_ENABLED = buildConf("spark.sql.legacy.utcTimestampFunc.enabled") + .doc("The configuration property enables the to_utc_timestamp() " + + "and from_utc_timestamp() functions.") + .booleanConf + .createWithDefault(false) + + val SOURCES_BINARY_FILE_MAX_LENGTH = buildConf("spark.sql.sources.binaryFile.maxLength") + .doc("The max length of a file that can be read by the binary file data source. " + + "Spark will fail fast and not attempt to read the file if its length exceeds this value. " + + "The theoretical max is Int.MaxValue, though VMs might implement a smaller max.") + .internal() + .intConf + .createWithDefault(Int.MaxValue) + + val LEGACY_CAST_DATETIME_TO_STRING = + buildConf("spark.sql.legacy.typeCoercion.datetimeToString") + .doc("If it is set to true, date/timestamp will cast to string in binary comparisons " + + "with String") + .booleanConf + .createWithDefault(false) +======= + + val DATETIME_JAVA8API_ENABLED = buildConf("spark.sql.datetime.java8API.enabled") + .doc("If the configuration property is set to true, java.time.Instant and " + + "java.time.LocalDate classes of Java 8 API are used as external types for " + + "Catalyst's TimestampType and DateType. If it is set to false, java.sql.Timestamp " + + "and java.sql.Date are used for the same purpose.") + .booleanConf + .createWithDefault(false) + + val UTC_TIMESTAMP_FUNC_ENABLED = buildConf("spark.sql.legacy.utcTimestampFunc.enabled") + .doc("The configuration property enables the to_utc_timestamp() " + + "and from_utc_timestamp() functions.") + .booleanConf + .createWithDefault(false) + + val SOURCES_BINARY_FILE_MAX_LENGTH = buildConf("spark.sql.sources.binaryFile.maxLength") + .doc("The max length of a file that can be read by the binary file data source. " + + "Spark will fail fast and not attempt to read the file if its length exceeds this value. 
" + + "The theoretical max is Int.MaxValue, though VMs might implement a smaller max.") + .internal() + .intConf + .createWithDefault(Int.MaxValue) + + val LEGACY_CAST_DATETIME_TO_STRING = + buildConf("spark.sql.legacy.typeCoercion.datetimeToString") + .doc("If it is set to true, date/timestamp will cast to string in binary comparisons " + + "with String") + .booleanConf + .createWithDefault(false) + + val DEFAULT_V2_CATALOG = buildConf("spark.sql.default.catalog") + .doc("Name of the default v2 catalog, used when a catalog is not identified in queries") + .stringConf + .createOptional +>>>>>>> bc46feaced... [SPARK-27693][SQL] Add default catalog property } /** @@ -2150,6 +2217,16 @@ class SQLConf extends Serializable with Logging { def setCommandRejectsSparkCoreConfs: Boolean = getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CORE_CONFS) +<<<<<<< HEAD +||||||| parent of bc46feaced... [SPARK-27693][SQL] Add default catalog property + def castDatetimeToString: Boolean = getConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING) + +======= + def castDatetimeToString: Boolean = getConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING) + + def defaultV2Catalog: Option[String] = getConf(DEFAULT_V2_CATALOG) + +>>>>>>> bc46feaced... [SPARK-27693][SQL] Add default catalog property /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala index 7d34b6568a4fc..19881f69f158c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/DataSourceResolution.scala @@ -44,6 +44,8 @@ case class DataSourceResolution( override protected def lookupCatalog(name: String): CatalogPlugin = findCatalog(name) + def defaultCatalog: Option[CatalogPlugin] = conf.defaultV2Catalog.map(findCatalog) + override def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case CreateTableStatement( AsTableIdentifier(table), schema, partitionCols, bucketSpec, properties, @@ -77,7 +79,7 @@ case class DataSourceResolution( case create: CreateTableAsSelectStatement => // the provider was not a v1 source, convert to a v2 plan val CatalogObjectIdentifier(maybeCatalog, identifier) = create.tableName - val catalog = maybeCatalog + val catalog = maybeCatalog.orElse(defaultCatalog) .getOrElse(throw new AnalysisException( s"No catalog specified for table ${identifier.quoted} and no default catalog is set")) .asTableCatalog diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index 60801910c6dbc..06f7332086372 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -410,8 +410,7 @@ class PlanResolutionSuite extends AnalysisTest { } } - // TODO(rblue): enable this test after the default catalog is available - ignore("Test v2 CTAS with data source v2 provider") { + test("Test v2 CTAS with data source v2 provider") { val sql = s""" |CREATE TABLE IF NOT EXISTS mydb.page_view diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala index eaef458d38386..5b9071b59b9b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2SQLSuite.scala @@ -170,8 +170,7 @@ class DataSourceV2SQLSuite extends QueryTest with SharedSQLContext with BeforeAn checkAnswer(spark.internalCreateDataFrame(rdd, table.schema), spark.table("source")) } - // TODO(rblue): enable this test after the default catalog is available - ignore("CreateTableAsSelect: use v2 plan because provider is v2") { + test("CreateTableAsSelect: use v2 plan because provider is v2") { spark.sql(s"CREATE TABLE table_name USING $orc2 AS SELECT id, data FROM source") val testCatalog = spark.catalog("testcat").asTableCatalog From f7e63d64af41a106448680709842db6fb1c92c48 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 6 Jun 2019 15:11:26 -0700 Subject: [PATCH 63/70] Fix merge conflicts --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 7 ------- 1 file changed, 7 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c7dc505f4bd3e..196f5c79ab34c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1700,8 +1700,6 @@ object SQLConf { "a SparkConf entry.") .booleanConf .createWithDefault(true) -<<<<<<< HEAD -||||||| parent of bc46feaced... [SPARK-27693][SQL] Add default catalog property val DATETIME_JAVA8API_ENABLED = buildConf("spark.sql.datetime.java8API.enabled") .doc("If the configuration property is set to true, java.time.Instant and " + @@ -1766,7 +1764,6 @@ object SQLConf { .doc("Name of the default v2 catalog, used when a catalog is not identified in queries") .stringConf .createOptional ->>>>>>> bc46feaced... [SPARK-27693][SQL] Add default catalog property } /** @@ -2217,16 +2214,12 @@ class SQLConf extends Serializable with Logging { def setCommandRejectsSparkCoreConfs: Boolean = getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CORE_CONFS) -<<<<<<< HEAD -||||||| parent of bc46feaced... [SPARK-27693][SQL] Add default catalog property def castDatetimeToString: Boolean = getConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING) -======= def castDatetimeToString: Boolean = getConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING) def defaultV2Catalog: Option[String] = getConf(DEFAULT_V2_CATALOG) ->>>>>>> bc46feaced... [SPARK-27693][SQL] Add default catalog property /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ From 876d1a0f15189add80b4ca190ea47e5f4cd4fc9f Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 6 Jun 2019 15:12:11 -0700 Subject: [PATCH 64/70] Revert "Fix merge conflicts" This reverts commit f7e63d64af41a106448680709842db6fb1c92c48. 
--- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 196f5c79ab34c..c7dc505f4bd3e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1700,6 +1700,8 @@ object SQLConf { "a SparkConf entry.") .booleanConf .createWithDefault(true) +<<<<<<< HEAD +||||||| parent of bc46feaced... [SPARK-27693][SQL] Add default catalog property val DATETIME_JAVA8API_ENABLED = buildConf("spark.sql.datetime.java8API.enabled") .doc("If the configuration property is set to true, java.time.Instant and " + @@ -1764,6 +1766,7 @@ object SQLConf { .doc("Name of the default v2 catalog, used when a catalog is not identified in queries") .stringConf .createOptional +>>>>>>> bc46feaced... [SPARK-27693][SQL] Add default catalog property } /** @@ -2214,12 +2217,16 @@ class SQLConf extends Serializable with Logging { def setCommandRejectsSparkCoreConfs: Boolean = getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CORE_CONFS) +<<<<<<< HEAD +||||||| parent of bc46feaced... [SPARK-27693][SQL] Add default catalog property def castDatetimeToString: Boolean = getConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING) +======= def castDatetimeToString: Boolean = getConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING) def defaultV2Catalog: Option[String] = getConf(DEFAULT_V2_CATALOG) +>>>>>>> bc46feaced... [SPARK-27693][SQL] Add default catalog property /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ From b714508522df3cff48d429a7df19266967b25653 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 6 Jun 2019 15:15:56 -0700 Subject: [PATCH 65/70] FIx merge conflicts again --- .../apache/spark/sql/internal/SQLConf.scala | 70 ------------------- 1 file changed, 70 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index c7dc505f4bd3e..cbc57066163c1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1700,73 +1700,11 @@ object SQLConf { "a SparkConf entry.") .booleanConf .createWithDefault(true) -<<<<<<< HEAD -||||||| parent of bc46feaced... [SPARK-27693][SQL] Add default catalog property - - val DATETIME_JAVA8API_ENABLED = buildConf("spark.sql.datetime.java8API.enabled") - .doc("If the configuration property is set to true, java.time.Instant and " + - "java.time.LocalDate classes of Java 8 API are used as external types for " + - "Catalyst's TimestampType and DateType. If it is set to false, java.sql.Timestamp " + - "and java.sql.Date are used for the same purpose.") - .booleanConf - .createWithDefault(false) - - val UTC_TIMESTAMP_FUNC_ENABLED = buildConf("spark.sql.legacy.utcTimestampFunc.enabled") - .doc("The configuration property enables the to_utc_timestamp() " + - "and from_utc_timestamp() functions.") - .booleanConf - .createWithDefault(false) - - val SOURCES_BINARY_FILE_MAX_LENGTH = buildConf("spark.sql.sources.binaryFile.maxLength") - .doc("The max length of a file that can be read by the binary file data source. " + - "Spark will fail fast and not attempt to read the file if its length exceeds this value. 
" + - "The theoretical max is Int.MaxValue, though VMs might implement a smaller max.") - .internal() - .intConf - .createWithDefault(Int.MaxValue) - - val LEGACY_CAST_DATETIME_TO_STRING = - buildConf("spark.sql.legacy.typeCoercion.datetimeToString") - .doc("If it is set to true, date/timestamp will cast to string in binary comparisons " + - "with String") - .booleanConf - .createWithDefault(false) -======= - - val DATETIME_JAVA8API_ENABLED = buildConf("spark.sql.datetime.java8API.enabled") - .doc("If the configuration property is set to true, java.time.Instant and " + - "java.time.LocalDate classes of Java 8 API are used as external types for " + - "Catalyst's TimestampType and DateType. If it is set to false, java.sql.Timestamp " + - "and java.sql.Date are used for the same purpose.") - .booleanConf - .createWithDefault(false) - - val UTC_TIMESTAMP_FUNC_ENABLED = buildConf("spark.sql.legacy.utcTimestampFunc.enabled") - .doc("The configuration property enables the to_utc_timestamp() " + - "and from_utc_timestamp() functions.") - .booleanConf - .createWithDefault(false) - - val SOURCES_BINARY_FILE_MAX_LENGTH = buildConf("spark.sql.sources.binaryFile.maxLength") - .doc("The max length of a file that can be read by the binary file data source. " + - "Spark will fail fast and not attempt to read the file if its length exceeds this value. " + - "The theoretical max is Int.MaxValue, though VMs might implement a smaller max.") - .internal() - .intConf - .createWithDefault(Int.MaxValue) - - val LEGACY_CAST_DATETIME_TO_STRING = - buildConf("spark.sql.legacy.typeCoercion.datetimeToString") - .doc("If it is set to true, date/timestamp will cast to string in binary comparisons " + - "with String") - .booleanConf - .createWithDefault(false) val DEFAULT_V2_CATALOG = buildConf("spark.sql.default.catalog") .doc("Name of the default v2 catalog, used when a catalog is not identified in queries") .stringConf .createOptional ->>>>>>> bc46feaced... [SPARK-27693][SQL] Add default catalog property } /** @@ -2217,16 +2155,8 @@ class SQLConf extends Serializable with Logging { def setCommandRejectsSparkCoreConfs: Boolean = getConf(SQLConf.SET_COMMAND_REJECTS_SPARK_CORE_CONFS) -<<<<<<< HEAD -||||||| parent of bc46feaced... [SPARK-27693][SQL] Add default catalog property - def castDatetimeToString: Boolean = getConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING) - -======= - def castDatetimeToString: Boolean = getConf(SQLConf.LEGACY_CAST_DATETIME_TO_STRING) - def defaultV2Catalog: Option[String] = getConf(DEFAULT_V2_CATALOG) ->>>>>>> bc46feaced... [SPARK-27693][SQL] Add default catalog property /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. 
*/ From c018fba565215ce6be449797f96a121931d576c7 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 6 Jun 2019 15:27:39 -0700 Subject: [PATCH 66/70] Fix style --- .../spark/sql/execution/python/FlatMapGroupsInPandasExec.scala | 2 -- 1 file changed, 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala index 18b074b807807..c598b7c671a42 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/FlatMapGroupsInPandasExec.scala @@ -28,9 +28,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.physical.{AllTuples, ClusteredDistribution, Distribution, Partitioning} import org.apache.spark.sql.execution.{GroupedIterator, SparkPlan, UnaryExecNode} import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} /** * Physical node for [[org.apache.spark.sql.catalyst.plans.logical.FlatMapGroupsInPandas]] From 2cd8078aa82c4700f446075ad9ca5c3ac1fe692a Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 6 Jun 2019 15:45:15 -0700 Subject: [PATCH 67/70] Fix test build. --- .../test/scala/org/apache/spark/sql/streaming/StreamTest.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala index 11acb534e9d69..fc72c940b922a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.streaming +import java.lang.Thread.UncaughtExceptionHandler + import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.language.experimental.macros From 5346dcf9b695588bfddfd34f5ff0708f88c8e2d1 Mon Sep 17 00:00:00 2001 From: francis0407 Date: Tue, 9 Apr 2019 21:45:46 +0800 Subject: [PATCH 68/70] [SPARK-27411][SQL] DataSourceV2Strategy should not eliminate subquery In DataSourceV2Strategy, it seems we eliminate the subqueries by mistake after normalizing filters. We have a sql with a scalar subquery: ``` scala val plan = spark.sql("select * from t2 where t2a > (select max(t1a) from t1)") plan.explain(true) ``` And we get the log info of DataSourceV2Strategy: ``` Pushing operators to csv:examples/src/main/resources/t2.txt Pushed Filters: Post-Scan Filters: isnotnull(t2a#30) Output: t2a#30, t2b#31 ``` The `Post-Scan Filters` should contain the scalar subquery, but we eliminate it by mistake. 
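The plans below show the subquery surviving analysis and optimization but disappearing from the physical Filter. The essence of the fix is to keep subquery filters out of the pushdown path entirely; a condensed sketch of the `DataSourceV2Strategy` change (not the literal hunk, which follows further down):

```scala
import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression}

// Filters containing a subquery cannot be evaluated by a data source, so they are set
// aside before normalization/pushdown and re-attached to the post-scan filters.
def splitForPushdown(filters: Seq[Expression]): (Seq[Expression], Seq[Expression]) = {
  // left: kept as post-scan filters; right: candidates for normalization and pushdown
  filters.partition(SubqueryExpression.hasSubquery)
}
```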
``` == Parsed Logical Plan == 'Project [*] +- 'Filter ('t2a > scalar-subquery#56 []) : +- 'Project [unresolvedalias('max('t1a), None)] : +- 'UnresolvedRelation `t1` +- 'UnresolvedRelation `t2` == Analyzed Logical Plan == t2a: string, t2b: string Project [t2a#30, t2b#31] +- Filter (t2a#30 > scalar-subquery#56 []) : +- Aggregate [max(t1a#13) AS max(t1a)#63] : +- SubqueryAlias `t1` : +- RelationV2[t1a#13, t1b#14] csv:examples/src/main/resources/t1.txt +- SubqueryAlias `t2` +- RelationV2[t2a#30, t2b#31] csv:examples/src/main/resources/t2.txt == Optimized Logical Plan == Filter (isnotnull(t2a#30) && (t2a#30 > scalar-subquery#56 [])) : +- Aggregate [max(t1a#13) AS max(t1a)#63] : +- Project [t1a#13] : +- RelationV2[t1a#13, t1b#14] csv:examples/src/main/resources/t1.txt +- RelationV2[t2a#30, t2b#31] csv:examples/src/main/resources/t2.txt == Physical Plan == *(1) Project [t2a#30, t2b#31] +- *(1) Filter isnotnull(t2a#30) +- *(1) BatchScan[t2a#30, t2b#31] class org.apache.spark.sql.execution.datasources.v2.csv.CSVScan ``` ut Closes #24321 from francis0407/SPARK-27411. Authored-by: francis0407 Signed-off-by: Wenchen Fan --- .../datasources/v2/DataSourceV2Strategy.scala | 13 ++++++++++++- .../spark/sql/sources/v2/DataSourceV2Suite.scala | 13 +++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index d78b95336a76e..d807da442b9b8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -108,12 +108,23 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => val scanBuilder = relation.newScanBuilder() +<<<<<<< HEAD val normalizedFilters = DataSourceStrategy.normalizeFilters(filters, relation.output) +||||||| parent of 601fac2cb3... [SPARK-27411][SQL] DataSourceV2Strategy should not eliminate subquery + val normalizedFilters = DataSourceStrategy.normalizeFilters( + filters.filterNot(SubqueryExpression.hasSubquery), relation.output) +======= + val (withSubquery, withoutSubquery) = filters.partition(SubqueryExpression.hasSubquery) + val normalizedFilters = DataSourceStrategy.normalizeFilters( + withoutSubquery, relation.output) +>>>>>>> 601fac2cb3... [SPARK-27411][SQL] DataSourceV2Strategy should not eliminate subquery // `pushedFilters` will be pushed down and evaluated in the underlying data sources. // `postScanFilters` need to be evaluated after the scan. // `postScanFilters` and `pushedFilters` can overlap, e.g. the parquet row group filter. 
- val (pushedFilters, postScanFilters) = pushFilters(scanBuilder, normalizedFilters) + val (pushedFilters, postScanFiltersWithoutSubquery) = + pushFilters(scanBuilder, normalizedFilters) + val postScanFilters = postScanFiltersWithoutSubquery ++ withSubquery val (scan, output) = pruneColumns(scanBuilder, relation, project ++ postScanFilters) logInfo( s""" diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 379c9c4303cd6..7d8cb8f9b1849 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -387,6 +387,19 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { checkAnswer(df, (0 until 3).map(i => Row(i))) } } + + test("SPARK-27411: DataSourceV2Strategy should not eliminate subquery") { + withTempView("t1") { + val t2 = spark.read.format(classOf[SimpleDataSourceV2].getName).load() + Seq(2, 3).toDF("a").createTempView("t1") + val df = t2.where("i < (select max(a) from t1)").select('i) + val subqueries = df.queryExecution.executedPlan.collect { + case p => p.subqueries + }.flatten + assert(subqueries.length == 1) + checkAnswer(df, (0 until 3).map(i => Row(i))) + } + } } From 17bb20c7dda3a37c14aa70b6a1d4c2fadc06e11b Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 6 Jun 2019 18:16:08 -0700 Subject: [PATCH 69/70] Fix merge conflict --- .../execution/datasources/v2/DataSourceV2Strategy.scala | 7 ------- 1 file changed, 7 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index d807da442b9b8..61d98e8aab5d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -108,16 +108,9 @@ object DataSourceV2Strategy extends Strategy with PredicateHelper { case PhysicalOperation(project, filters, relation: DataSourceV2Relation) => val scanBuilder = relation.newScanBuilder() -<<<<<<< HEAD - val normalizedFilters = DataSourceStrategy.normalizeFilters(filters, relation.output) -||||||| parent of 601fac2cb3... [SPARK-27411][SQL] DataSourceV2Strategy should not eliminate subquery - val normalizedFilters = DataSourceStrategy.normalizeFilters( - filters.filterNot(SubqueryExpression.hasSubquery), relation.output) -======= val (withSubquery, withoutSubquery) = filters.partition(SubqueryExpression.hasSubquery) val normalizedFilters = DataSourceStrategy.normalizeFilters( withoutSubquery, relation.output) ->>>>>>> 601fac2cb3... [SPARK-27411][SQL] DataSourceV2Strategy should not eliminate subquery // `pushedFilters` will be pushed down and evaluated in the underlying data sources. // `postScanFilters` need to be evaluated after the scan. 
From 5a8ea0b7053b6456c44ab0a4e792a6dcf12d3463 Mon Sep 17 00:00:00 2001 From: mcheah Date: Thu, 6 Jun 2019 18:18:52 -0700 Subject: [PATCH 70/70] Fix build --- .../datasources/v2/DataSourceV2Strategy.scala | 2 +- .../spark/sql/sources/v2/DataSourceV2Suite.scala | 13 ------------- 2 files changed, 1 insertion(+), 14 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 61d98e8aab5d6..9889fd6731565 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -21,7 +21,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable import org.apache.spark.sql.{AnalysisException, Strategy} -import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper} +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, AttributeSet, Expression, PredicateHelper, SubqueryExpression} import org.apache.spark.sql.catalyst.planning.PhysicalOperation import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, CreateV2Table, DropTable, LogicalPlan, OverwriteByExpression, OverwritePartitionsDynamic, Repartition} import org.apache.spark.sql.execution.{FilterExec, ProjectExec, SparkPlan} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala index 7d8cb8f9b1849..379c9c4303cd6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala @@ -387,19 +387,6 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext { checkAnswer(df, (0 until 3).map(i => Row(i))) } } - - test("SPARK-27411: DataSourceV2Strategy should not eliminate subquery") { - withTempView("t1") { - val t2 = spark.read.format(classOf[SimpleDataSourceV2].getName).load() - Seq(2, 3).toDF("a").createTempView("t1") - val df = t2.where("i < (select max(a) from t1)").select('i) - val subqueries = df.queryExecution.executedPlan.collect { - case p => p.subqueries - }.flatten - assert(subqueries.length == 1) - checkAnswer(df, (0 until 3).map(i => Row(i))) - } - } }