From 9bbfe6ef4b5a418373c2250ad676233fb05df7f7 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Mon, 25 Dec 2017 10:29:53 +0800 Subject: [PATCH 01/55] [SPARK-21786][SQL] When acquiring 'compressionCodecClassName' in 'ParquetOptions', `parquet.compression` needs to be considered. ## What changes were proposed in this pull request? 1.Increased acquiring 'compressionCodecClassName' from `parquet.compression`,and the order is `compression`,`parquet.compression`,`spark.sql.parquet.compression.codec`, just like what we do in `OrcOptions`. 2.Change `spark.sql.parquet.compression.codec` to support "none".Actually in `ParquetOptions`,we do support "none" as equivalent to "uncompressed", but it does not allowed to configured to "none". ## How was this patch tested? Manual test. --- .../main/scala/org/apache/spark/sql/internal/SQLConf.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 84fe4bb711a4e..13410870cbb0e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -324,10 +324,10 @@ object SQLConf { val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") .doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " + - "uncompressed, snappy, gzip, lzo.") + "none, uncompressed, snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) - .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) + .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") From 48cf108ed5c3298eb860d9735b439ac89d65765e Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Mon, 25 Dec 2017 10:30:24 +0800 Subject: [PATCH 02/55] [SPARK-21786][SQL] When acquiring 'compressionCodecClassName' in 'ParquetOptions', `parquet.compression` needs to be considered. ## What changes were proposed in this pull request? 1.Increased acquiring 'compressionCodecClassName' from `parquet.compression`,and the order is `compression`,`parquet.compression`,`spark.sql.parquet.compression.codec`, just like what we do in `OrcOptions`. 2.Change `spark.sql.parquet.compression.codec` to support "none".Actually in `ParquetOptions`,we do support "none" as equivalent to "uncompressed", but it does not allowed to configured to "none". ## How was this patch tested? Manual test. --- .../datasources/parquet/ParquetOptions.scala | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 772d4565de548..a81022409a119 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.Locale import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.internal.SQLConf @@ -42,8 +43,15 @@ private[parquet] class ParquetOptions( * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. */ val compressionCodecClassName: String = { - val codecName = parameters.getOrElse("compression", - sqlConf.parquetCompressionCodec).toLowerCase(Locale.ROOT) + // `compression`, `parquet.compression`(i.e., ParquetOutputFormat.COMPRESSION), and + // `spark.sql.parquet.compression.codec` + // are in order of precedence from highest to lowest. + val parquetCompressionConf = parameters.get(ParquetOutputFormat.COMPRESSION) + val codecName = parameters + .get("compression") + .orElse(parquetCompressionConf) + .getOrElse(sqlConf.parquetCompressionCodec) + .toLowerCase(Locale.ROOT) if (!shortParquetCompressionCodecNames.contains(codecName)) { val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) From 5dbd3edf9e086433d3d3fe9c0ead887d799c61d3 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Mon, 25 Dec 2017 10:34:29 +0800 Subject: [PATCH 03/55] spark.sql.parquet.compression.codec[SPARK-21786][SQL] When acquiring 'compressionCodecClassName' in 'ParquetOptions', `parquet.compression` needs to be considered. ## What changes were proposed in this pull request? 1.Increased acquiring 'compressionCodecClassName' from `parquet.compression`,and the order is `compression`,`parquet.compression`,`spark.sql.parquet.compression.codec`, just like what we do in `OrcOptions`. 2.Change `spark.sql.parquet.compression.codec` to support "none".Actually in `ParquetOptions`,we do support "none" as equivalent to "uncompressed", but it does not allowed to configured to "none". ## How was this patch tested? Manual test. --- docs/sql-programming-guide.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index f02f46236e2b0..51c1c068df1c4 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -954,7 +954,7 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession snappy Sets the compression codec use when writing Parquet files. Acceptable values include: - uncompressed, snappy, gzip, lzo. + none, uncompressed, snappy, gzip, lzo. From 5124f1b560e942c0dc23af31336317a4b995dd8f Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Mon, 25 Dec 2017 15:06:26 +0800 Subject: [PATCH 04/55] spark.sql.parquet.compression.codec[SPARK-21786][SQL] When acquiring 'compressionCodecClassName' in 'ParquetOptions', `parquet.compression` needs to be considered. ## What changes were proposed in this pull request? 1.Increased acquiring 'compressionCodecClassName' from `parquet.compression`,and the order is `compression`,`parquet.compression`,`spark.sql.parquet.compression.codec`, just like what we do in `OrcOptions`. 2.Change `spark.sql.parquet.compression.codec` to support "none".Actually in `ParquetOptions`,we do support "none" as equivalent to "uncompressed", but it does not allowed to configured to "none". 3.Change `compressionCode` to `compressionCodecClassName`. ## How was this patch tested? Manual test. --- docs/sql-programming-guide.md | 4 +- .../datasources/orc/OrcFileFormat.scala | 2 +- .../datasources/orc/OrcOptions.scala | 2 +- .../datasources/orc/OrcSourceSuite.scala | 13 ++-- .../spark/sql/hive/orc/OrcFileFormat.scala | 2 +- .../sql/hive/CompressionCodecSuite.scala | 61 +++++++++++++++++++ 6 files changed, 74 insertions(+), 10 deletions(-) create mode 100644 sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index 51c1c068df1c4..bd17d73683f22 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -953,7 +953,9 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession spark.sql.parquet.compression.codec snappy - Sets the compression codec use when writing Parquet files. Acceptable values include: + Sets the compression codec use when writing Parquet files. If other compression codec + configuration was found through hive or parquet, the precedence would be `compression`, + `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: none, uncompressed, snappy, gzip, lzo. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index f7471cd7debce..b6c135bb23219 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -94,7 +94,7 @@ class OrcFileFormat conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, dataSchema.catalogString) - conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec) + conf.set(COMPRESS.getAttribute, orcOptions.compressionCodecClassName) conf.asInstanceOf[JobConf] .setOutputFormat(classOf[org.apache.orc.mapred.OrcOutputFormat[OrcStruct]]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index c866dd834a525..12930c8e3b971 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -41,7 +41,7 @@ class OrcOptions( * Compression codec to use. * Acceptable values are defined in [[shortOrcCompressionCodecNames]]. */ - val compressionCodec: String = { + val compressionCodecClassName: String = { // `compression`, `orc.compress`(i.e., OrcConf.COMPRESS), and `spark.sql.orc.compression.codec` // are in order of precedence from highest to lowest. val orcCompressionConf = parameters.get(COMPRESS.getAttribute) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 6f5f2fd795f74..beef8c2f1c3b8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -134,29 +134,30 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { val conf = spark.sessionState.conf val option = new OrcOptions(Map(COMPRESS.getAttribute.toUpperCase(Locale.ROOT) -> "NONE"), conf) - assert(option.compressionCodec == "NONE") + assert(option.compressionCodecClassName == "NONE") } test("SPARK-21839: Add SQL config for ORC compression") { val conf = spark.sessionState.conf // Test if the default of spark.sql.orc.compression.codec is snappy - assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == "SNAPPY") + assert(new OrcOptions(Map.empty[String, String], conf).compressionCodecClassName == "SNAPPY") // OrcOptions's parameters have a higher priority than SQL configuration. // `compression` -> `orc.compression` -> `spark.sql.orc.compression.codec` withSQLConf(SQLConf.ORC_COMPRESSION.key -> "uncompressed") { - assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == "NONE") + assert(new OrcOptions(Map.empty[String, String], conf).compressionCodecClassName == "NONE") val map1 = Map(COMPRESS.getAttribute -> "zlib") val map2 = Map(COMPRESS.getAttribute -> "zlib", "compression" -> "lzo") - assert(new OrcOptions(map1, conf).compressionCodec == "ZLIB") - assert(new OrcOptions(map2, conf).compressionCodec == "LZO") + assert(new OrcOptions(map1, conf).compressionCodecClassName == "ZLIB") + assert(new OrcOptions(map2, conf).compressionCodecClassName == "LZO") } // Test all the valid options of spark.sql.orc.compression.codec Seq("NONE", "UNCOMPRESSED", "SNAPPY", "ZLIB", "LZO").foreach { c => withSQLConf(SQLConf.ORC_COMPRESSION.key -> c) { val expected = if (c == "UNCOMPRESSED") "NONE" else c - assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == expected) + assert( + new OrcOptions(Map.empty[String, String], conf).compressionCodecClassName == expected) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 95741c7b30289..5715893920ac6 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -74,7 +74,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val configuration = job.getConfiguration - configuration.set(COMPRESS.getAttribute, orcOptions.compressionCodec) + configuration.set(COMPRESS.getAttribute, orcOptions.compressionCodecClassName) configuration match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala new file mode 100644 index 0000000000000..8ce362b99e61a --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive + +import org.apache.parquet.hadoop.ParquetOutputFormat + +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.hive.test.TestHiveSingleton +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SQLTestUtils + +class CompressionCodecSuite extends TestHiveSingleton with SQLTestUtils { + test("Test `spark.sql.parquet.compression.codec` config") { + Seq("NONE", "UNCOMPRESSED", "SNAPPY", "ZLIB", "LZO").foreach { c => + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> c) { + val expected = if (c == "NONE") "UNCOMPRESSED" else c + val option = new ParquetOptions(Map.empty[String, String], spark.sessionState.conf) + assert(option.compressionCodecClassName == expected) + } + } + } + + test("[SPARK-21786] Test Acquiring 'compressionCodecClassName' for parquet in right order.") { + // When "compression" is configured, it should be the first choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map("compression" -> "uncompressed", ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "UNCOMPRESSED") + } + + // When "compression" is not configured, "parquet.compression" should be the preferred choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map(ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "GZIP") + } + + // When both "compression" and "parquet.compression" are not configured, + // spark.sql.parquet.compression.codec should be the right choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map.empty[String, String] + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "SNAPPY") + } + } +} From 6907a3ef86a2546fae91c22754796490a80effff Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Mon, 25 Dec 2017 17:26:33 +0800 Subject: [PATCH 05/55] Make comression codec take effect in hive table writing. --- .../datasources/orc/OrcOptions.scala | 2 +- .../datasources/parquet/ParquetOptions.scala | 6 ++--- .../sql/hive/execution/HiveOptions.scala | 22 +++++++++++++++++++ .../sql/hive/execution/SaveAsHiveFile.scala | 4 ++++ 4 files changed, 30 insertions(+), 4 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index 12930c8e3b971..d6332b2c31384 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -61,7 +61,7 @@ class OrcOptions( object OrcOptions { // The ORC compression short names - private val shortOrcCompressionCodecNames = Map( + val shortOrcCompressionCodecNames = Map( "none" -> "NONE", "uncompressed" -> "NONE", "snappy" -> "SNAPPY", diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index a81022409a119..383599b7fce3d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.Locale -import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.parquet.hadoop.ParquetOutputFormat +import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.internal.SQLConf @@ -28,7 +28,7 @@ import org.apache.spark.sql.internal.SQLConf /** * Options for the Parquet data source. */ -private[parquet] class ParquetOptions( +class ParquetOptions( @transient private val parameters: CaseInsensitiveMap[String], @transient private val sqlConf: SQLConf) extends Serializable { @@ -76,7 +76,7 @@ object ParquetOptions { val MERGE_SCHEMA = "mergeSchema" // The parquet compression short names - private val shortParquetCompressionCodecNames = Map( + val shortParquetCompressionCodecNames = Map( "none" -> CompressionCodecName.UNCOMPRESSED, "uncompressed" -> CompressionCodecName.UNCOMPRESSED, "snappy" -> CompressionCodecName.SNAPPY, diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala index 5c515515b9b9c..78f712b88d203 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala @@ -19,7 +19,16 @@ package org.apache.spark.sql.hive.execution import java.util.Locale +import scala.collection.JavaConverters._ + +import org.apache.hadoop.hive.ql.plan.TableDesc +import org.apache.orc.OrcConf.COMPRESS +import org.apache.parquet.hadoop.ParquetOutputFormat + import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap +import org.apache.spark.sql.execution.datasources.orc.OrcOptions +import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.internal.SQLConf /** * Options for the Hive data source. Note that rule `DetermineHiveSerde` will extract Hive @@ -102,4 +111,17 @@ object HiveOptions { "collectionDelim" -> "colelction.delim", "mapkeyDelim" -> "mapkey.delim", "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase(Locale.ROOT) -> v } + + def getHiveWriteCompression(tableInfo: TableDesc, sqlConf: SQLConf): Option[(String, String)]= { + val tableProps = tableInfo.getProperties.asScala.toMap + tableInfo.getOutputFileFormatClassName.toLowerCase match { + case formatName if formatName.endsWith("parquetoutputformat") => + val compressionCodec = new ParquetOptions(tableProps, sqlConf).compressionCodec + Option((ParquetOutputFormat.COMPRESSION, compressionCodec)) + case formatName if formatName.endsWith("orcoutputformat") => + val compressionCodec = new OrcOptions(tableProps, sqlConf).compressionCodec + Option((COMPRESS.getAttribute, compressionCodec)) + case _ => None + } + } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 63657590e5e79..734d8350f3870 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -68,6 +68,10 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { .get("mapreduce.output.fileoutputformat.compress.type")) } + // Set compression by priority + HiveOptions.getHiveWriteCompression(fileSinkConf.getTableInfo, sparkSession.sessionState.conf) + .foreach { case (compression, codec) => hadoopConf.set(compression, codec) } + val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, jobId = java.util.UUID.randomUUID().toString, From 67e40d4d7fd3b6a9e4526ce17bf6d4eadb05b2b8 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Mon, 25 Dec 2017 20:08:11 +0800 Subject: [PATCH 06/55] Modify test --- .../sql/hive/CompressionCodecSuite.scala | 353 ++++++++++++++++-- 1 file changed, 326 insertions(+), 27 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index 8ce362b99e61a..baccf88e4b0b0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -17,45 +17,344 @@ package org.apache.spark.sql.hive +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.orc.OrcConf.COMPRESS import org.apache.parquet.hadoop.ParquetOutputFormat -import org.apache.spark.sql.execution.datasources.parquet.ParquetOptions +import org.apache.spark.sql.execution.datasources.parquet.ParquetTest +import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SQLTestUtils -class CompressionCodecSuite extends TestHiveSingleton with SQLTestUtils { - test("Test `spark.sql.parquet.compression.codec` config") { - Seq("NONE", "UNCOMPRESSED", "SNAPPY", "ZLIB", "LZO").foreach { c => - withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> c) { - val expected = if (c == "NONE") "UNCOMPRESSED" else c - val option = new ParquetOptions(Map.empty[String, String], spark.sessionState.conf) - assert(option.compressionCodecClassName == expected) +class CompressionCodecSuite extends TestHiveSingleton with ParquetTest { + import spark.implicits._ + + private val maxRecordNum = 100000 + + private def getConvertMetastoreConfName(format: String): String = format match { + case "parquet" => HiveUtils.CONVERT_METASTORE_PARQUET.key + case "orc" => HiveUtils.CONVERT_METASTORE_ORC.key + } + + private def getSparkCompressionConfName(format: String): String = format match { + case "parquet" => SQLConf.PARQUET_COMPRESSION.key + case "orc" => SQLConf.ORC_COMPRESSION.key + } + + private def getHiveCompressPropName(format: String): String = { + format.toLowerCase match { + case "parquet" => ParquetOutputFormat.COMPRESSION + case "orc" => COMPRESS.getAttribute + } + } + + private def getTableCompressionCodec(path: String, format: String): String = { + val hadoopConf = spark.sessionState.newHadoopConf() + val codecs = format match { + case "parquet" => for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + case "orc" => new File(path).listFiles().filter{ file => + file.isFile && !file.getName.endsWith(".crc") && file.getName != "_SUCCESS" + }.map { orcFile => + OrcFileOperator.getFileReader(orcFile.toPath.toString).get.getCompression.toString + }.toSeq + } + + assert(codecs.distinct.length == 1) + codecs.head + } + + private def writeDataToTable( + rootDir: File, + tableName: String, + isPartitioned: Boolean, + format: String, + compressionCodec: Option[String], + dataSourceName: String = "table_source"): Unit = { + val tblProperties = compressionCodec match { + case Some(prop) => s"TBLPROPERTIES('${getHiveCompressPropName(format)}'='$prop')" + case _ => "" + } + val partitionCreate = if (isPartitioned) "PARTITIONED BY (p int)" else "" + sql( + s""" + |CREATE TABLE $tableName(a int) + |$partitionCreate + |STORED AS $format + |LOCATION '${rootDir.toURI.toString.stripSuffix("/")}/$tableName' + |$tblProperties + """.stripMargin) + + val partitionInsert = if (isPartitioned) s"partition (p=10000)" else "" + sql( + s""" + |INSERT OVERWRITE TABLE $tableName + |$partitionInsert + |SELECT * FROM $dataSourceName + """.stripMargin) + } + + private def getTableSize(path: String): Long = { + val dir = new File(path) + val files = dir.listFiles().filter(_.getName.startsWith("part-")) + files.map(_.length()).sum + } + + private def getDataSizeByFormat(format: String, compressionCodec: Option[String], isPartitioned: Boolean): Long = { + var totalSize = 0L + val tableName = s"tbl_$format${compressionCodec.mkString}" + withTempView("datasource_table") { + (0 until maxRecordNum).toDF("a").createOrReplaceTempView("datasource_table") + withSQLConf(getSparkCompressionConfName(format) -> compressionCodec.getOrElse("uncompressed")) { + withTempDir { tmpDir => + withTable(tableName) { + writeDataToTable(tmpDir, tableName, isPartitioned, format, compressionCodec, "datasource_table") + val partition = if (isPartitioned) "p=10000" else "" + val path =s"${tmpDir.getPath.stripSuffix("/")}/${tableName}/$partition" + totalSize = getTableSize(path) + } + } } } + assert(totalSize > 0L) + totalSize } - test("[SPARK-21786] Test Acquiring 'compressionCodecClassName' for parquet in right order.") { - // When "compression" is configured, it should be the first choice. - withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { - val props = Map("compression" -> "uncompressed", ParquetOutputFormat.COMPRESSION -> "gzip") - val option = new ParquetOptions(props, spark.sessionState.conf) - assert(option.compressionCodecClassName == "UNCOMPRESSED") + private def checkCompressionCodecForTable( + format: String, + isPartitioned: Boolean, + compressionCodec: Option[String]) + (assertion: (String, Long) => Unit): Unit = { + val tableName = s"tbl_$format${isPartitioned}" + withTempDir { tmpDir => + withTable(tableName) { + writeDataToTable(tmpDir, tableName, isPartitioned, format, compressionCodec) + val partition = if (isPartitioned) "p=10000" else "" + val path = s"${tmpDir.getPath.stripSuffix("/")}/${tableName}/$partition" + val relCompressionCodec = getTableCompressionCodec(path, format) + val tableSize = getTableSize(path) + assertion(relCompressionCodec, tableSize) + } } + } - // When "compression" is not configured, "parquet.compression" should be the preferred choice. - withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { - val props = Map(ParquetOutputFormat.COMPRESSION -> "gzip") - val option = new ParquetOptions(props, spark.sessionState.conf) - assert(option.compressionCodecClassName == "GZIP") + private def checkTableCompressionCodecForCodecs( + format: String, + isPartitioned: Boolean, + convertMetastore: Boolean, + compressionCodecs: List[String], + tableCompressionCodecs: List[String]) + (assertionCompressionCodec: (Option[String], String, String, Long) => Unit): Unit = { + withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString) { + tableCompressionCodecs.foreach { tableCompression => + compressionCodecs.foreach { sessionCompressionCodec => + withSQLConf(getSparkCompressionConfName(format) -> sessionCompressionCodec) { + // 'tableCompression = null' means no table-level compression + val compression = Option(tableCompression) + checkCompressionCodecForTable(format, isPartitioned, compression) { + case (realCompressionCodec, tableSize) => assertionCompressionCodec(compression, + sessionCompressionCodec, realCompressionCodec, tableSize) + } + } + } + } } + } + + // To check if the compressionCodec takes effect, we check the data size with uncompressed size. + // and because partitioned table's schema may different with non-partitioned table's schema when + // convertMetastore is true, e.g. parquet, we save them independently. + private val partitionedParquetTableUncompressedSize = getDataSizeByFormat("parquet", None, true) + private val nonpartitionedParquetTableUncompressedSize = getDataSizeByFormat("parquet", None, false) + + // Orc data seems be consistent within partitioned table and non-partitioned table, + // but we just make the code look the same. + private val partitionedOrcTableUncompressedSize = getDataSizeByFormat("orc", None, true) + private val nonpartitionedOrcTableUncompressedSize = getDataSizeByFormat("orc", None, false) + + // When the amount of data is small, compressed data size may be smaller than uncompressed one, + // so we just check the difference when compressionCodec is not NONE or UNCOMPRESSED. + // When convertMetastore is false, the uncompressed table size should be same as + // `partitionedParquetTableUncompressedSize`, regardless of whether there is a partition. + private def checkTableSize(format: String, compressionCodec: String, + ispartitioned: Boolean, convertMetastore: Boolean, tableSize: Long): Boolean = { + format match { + case "parquet" => val uncompressedSize = + if(!convertMetastore || ispartitioned) partitionedParquetTableUncompressedSize + else nonpartitionedParquetTableUncompressedSize - // When both "compression" and "parquet.compression" are not configured, - // spark.sql.parquet.compression.codec should be the right choice. - withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { - val props = Map.empty[String, String] - val option = new ParquetOptions(props, spark.sessionState.conf) - assert(option.compressionCodecClassName == "SNAPPY") + if(compressionCodec == "UNCOMPRESSED") tableSize == uncompressedSize + else tableSize != uncompressedSize + case "orc" => val uncompressedSize = + if(!convertMetastore || ispartitioned) partitionedOrcTableUncompressedSize + else nonpartitionedOrcTableUncompressedSize + + if(compressionCodec == "NONE") tableSize == uncompressedSize + else tableSize != uncompressedSize + case _ => false } } -} + + private def testCompressionCodec(testCondition: String)(f: => Unit): Unit = { + test("[SPARK-21786] - Check the priority between table-level compression and " + + s"session-level compression $testCondition") { + withTempView("table_source") { + (0 until maxRecordNum).toDF("a").createOrReplaceTempView("table_source") + f + } + } + } + + testCompressionCodec("when table-level and session-level compression are both configured and " + + "convertMetastore is false") { + def checkForTableWithCompressProp(format: String, compressCodecs: List[String]): Unit = { + // For tables with table-level compression property, when + // 'spark.sql.hive.convertMetastore[Parquet|Orc]' was set to 'false', partitioned tables + // and non-partitioned tables will always take the table-level compression + // configuration first and ignore session compression configuration. + // Check for partitioned table, when convertMetastore is false + checkTableCompressionCodecForCodecs( + format = format, + isPartitioned = true, + convertMetastore = false, + compressionCodecs = compressCodecs, + tableCompressionCodecs = compressCodecs) { + case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => + // Expect table-level take effect + assert(tableCompressionCodec.get == realCompressionCodec) + assert(checkTableSize(format, tableCompressionCodec.get, true, false, tableSize)) + } + + // Check for non-partitioned table, when convertMetastoreParquet is false + checkTableCompressionCodecForCodecs( + format = format, + isPartitioned = false, + convertMetastore = false, + compressionCodecs = compressCodecs, + tableCompressionCodecs = compressCodecs) { + case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => + // Expect table-level take effect + assert(tableCompressionCodec.get == realCompressionCodec) + assert(checkTableSize(format, tableCompressionCodec.get, false, false, tableSize)) + } + } + + checkForTableWithCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkForTableWithCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + } + + testCompressionCodec("when there's no table-level compression and convertMetastore is false") { + def checkForTableWithoutCompressProp(format: String, compressCodecs: List[String]): Unit = { + // For tables without table-level compression property, session-level compression + // configuration will take effect. + // Check for partitioned table, when convertMetastore is false + checkTableCompressionCodecForCodecs( + format = format, + isPartitioned = true, + convertMetastore = false, + compressionCodecs = compressCodecs, + tableCompressionCodecs = List(null)) { + case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => + // Expect session-level take effect + assert(sessionCompressionCodec == realCompressionCodec) + assert(checkTableSize(format, sessionCompressionCodec, true, false, tableSize)) + } + + // Check for non-partitioned table, when convertMetastore is false + checkTableCompressionCodecForCodecs( + format = format, + isPartitioned = false, + convertMetastore = false, + compressionCodecs = compressCodecs, + tableCompressionCodecs = List(null)) { + case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => + // Expect session-level take effect + assert(sessionCompressionCodec == realCompressionCodec) + assert(checkTableSize(format, sessionCompressionCodec, false, false, tableSize)) + } + } + + checkForTableWithoutCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkForTableWithoutCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + } + + testCompressionCodec("when table-level and session-level compression are both configured and " + + "convertMetastore is true") { + def checkForTableWithCompressProp(format: String, compressCodecs: List[String]): Unit = { + // For tables with table-level compression property, when + // 'spark.sql.hive.convertMetastore[Parquet|Orc]' was set to 'true', partitioned tables + // will always take the table-level compression configuration first, but non-partitioned + // tables will take the session-level compression configuration. + // Check for partitioned table, when convertMetastore is true + checkTableCompressionCodecForCodecs( + format = format, + isPartitioned = true, + convertMetastore = true, + compressionCodecs = compressCodecs, + tableCompressionCodecs = compressCodecs) { + case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => + // Expect table-level take effect + assert(tableCompressionCodec.get == realCompressionCodec) + assert(checkTableSize(format, tableCompressionCodec.get, true, true, tableSize)) + } + + // Check for non-partitioned table, when convertMetastore is true + checkTableCompressionCodecForCodecs( + format = format, + isPartitioned = false, + convertMetastore = true, + compressionCodecs = compressCodecs, + tableCompressionCodecs = compressCodecs) { + case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => + // Expect session-level take effect + assert(sessionCompressionCodec == realCompressionCodec) + assert(checkTableSize(format, sessionCompressionCodec, false, true, tableSize)) + } + } + + checkForTableWithCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkForTableWithCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + } + + testCompressionCodec("when there's no table-level compression and convertMetastore is true") { + def checkForTableWithoutCompressProp(format: String, compressCodecs: List[String]): Unit = { + // For tables without table-level compression property, session-level compression + // configuration will take effect. + // Check for partitioned table, when convertMetastore is true + checkTableCompressionCodecForCodecs( + format = format, + isPartitioned = true, + convertMetastore = true, + compressionCodecs = compressCodecs, + tableCompressionCodecs = List(null)) { + case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => + // Expect session-level take effect + assert(sessionCompressionCodec == realCompressionCodec) + assert(checkTableSize(format, sessionCompressionCodec, true, true, tableSize)) + } + + // Check for non-partitioned table, when convertMetastore is true + checkTableCompressionCodecForCodecs( + format = format, + isPartitioned = false, + convertMetastore = true, + compressionCodecs = compressCodecs, + tableCompressionCodecs = List(null)) { + case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => + // Expect session-level take effect + assert(sessionCompressionCodec == realCompressionCodec) + assert(checkTableSize(format, sessionCompressionCodec, false, true, tableSize)) + } + } + + checkForTableWithoutCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkForTableWithoutCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + } +} \ No newline at end of file From e2526ca1bb72e54c03d977b8678bd14b28c83585 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Tue, 26 Dec 2017 13:38:10 +0800 Subject: [PATCH 07/55] Separate the pr --- docs/sql-programming-guide.md | 6 ++---- .../org/apache/spark/sql/internal/SQLConf.scala | 4 ++-- .../execution/datasources/orc/OrcFileFormat.scala | 2 +- .../sql/execution/datasources/orc/OrcOptions.scala | 2 +- .../datasources/parquet/ParquetOptions.scala | 12 ++---------- .../execution/datasources/orc/OrcSourceSuite.scala | 12 ++++++------ .../apache/spark/sql/hive/orc/OrcFileFormat.scala | 2 +- 7 files changed, 15 insertions(+), 25 deletions(-) diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index bd17d73683f22..f02f46236e2b0 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -953,10 +953,8 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession spark.sql.parquet.compression.codec snappy - Sets the compression codec use when writing Parquet files. If other compression codec - configuration was found through hive or parquet, the precedence would be `compression`, - `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: - none, uncompressed, snappy, gzip, lzo. + Sets the compression codec use when writing Parquet files. Acceptable values include: + uncompressed, snappy, gzip, lzo. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 13410870cbb0e..84fe4bb711a4e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -324,10 +324,10 @@ object SQLConf { val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") .doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " + - "none, uncompressed, snappy, gzip, lzo.") + "uncompressed, snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) - .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) + .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala index b6c135bb23219..f7471cd7debce 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala @@ -94,7 +94,7 @@ class OrcFileFormat conf.set(MAPRED_OUTPUT_SCHEMA.getAttribute, dataSchema.catalogString) - conf.set(COMPRESS.getAttribute, orcOptions.compressionCodecClassName) + conf.set(COMPRESS.getAttribute, orcOptions.compressionCodec) conf.asInstanceOf[JobConf] .setOutputFormat(classOf[org.apache.orc.mapred.OrcOutputFormat[OrcStruct]]) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index d6332b2c31384..faa1f4a4943f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -41,7 +41,7 @@ class OrcOptions( * Compression codec to use. * Acceptable values are defined in [[shortOrcCompressionCodecNames]]. */ - val compressionCodecClassName: String = { + val compressionCodec: String = { // `compression`, `orc.compress`(i.e., OrcConf.COMPRESS), and `spark.sql.orc.compression.codec` // are in order of precedence from highest to lowest. val orcCompressionConf = parameters.get(COMPRESS.getAttribute) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 383599b7fce3d..a722dbfb4fadf 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.Locale -import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -43,15 +42,8 @@ class ParquetOptions( * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. */ val compressionCodecClassName: String = { - // `compression`, `parquet.compression`(i.e., ParquetOutputFormat.COMPRESSION), and - // `spark.sql.parquet.compression.codec` - // are in order of precedence from highest to lowest. - val parquetCompressionConf = parameters.get(ParquetOutputFormat.COMPRESSION) - val codecName = parameters - .get("compression") - .orElse(parquetCompressionConf) - .getOrElse(sqlConf.parquetCompressionCodec) - .toLowerCase(Locale.ROOT) + val codecName = parameters.getOrElse("compression", + sqlConf.parquetCompressionCodec).toLowerCase(Locale.ROOT) if (!shortParquetCompressionCodecNames.contains(codecName)) { val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index beef8c2f1c3b8..e2c09d3ac3cae 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -134,22 +134,22 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { test("SPARK-18433: Improve DataSource option keys to be more case-insensitive") { val conf = spark.sessionState.conf val option = new OrcOptions(Map(COMPRESS.getAttribute.toUpperCase(Locale.ROOT) -> "NONE"), conf) - assert(option.compressionCodecClassName == "NONE") + assert(option.compressionCodec == "NONE") } test("SPARK-21839: Add SQL config for ORC compression") { val conf = spark.sessionState.conf // Test if the default of spark.sql.orc.compression.codec is snappy - assert(new OrcOptions(Map.empty[String, String], conf).compressionCodecClassName == "SNAPPY") + assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == "SNAPPY") // OrcOptions's parameters have a higher priority than SQL configuration. // `compression` -> `orc.compression` -> `spark.sql.orc.compression.codec` withSQLConf(SQLConf.ORC_COMPRESSION.key -> "uncompressed") { - assert(new OrcOptions(Map.empty[String, String], conf).compressionCodecClassName == "NONE") + assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == "NONE") val map1 = Map(COMPRESS.getAttribute -> "zlib") val map2 = Map(COMPRESS.getAttribute -> "zlib", "compression" -> "lzo") - assert(new OrcOptions(map1, conf).compressionCodecClassName == "ZLIB") - assert(new OrcOptions(map2, conf).compressionCodecClassName == "LZO") + assert(new OrcOptions(map1, conf).compressionCodec == "ZLIB") + assert(new OrcOptions(map2, conf).compressionCodec == "LZO") } // Test all the valid options of spark.sql.orc.compression.codec @@ -157,7 +157,7 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { withSQLConf(SQLConf.ORC_COMPRESSION.key -> c) { val expected = if (c == "UNCOMPRESSED") "NONE" else c assert( - new OrcOptions(Map.empty[String, String], conf).compressionCodecClassName == expected) + new OrcOptions(Map.empty[String, String], conf).compressionCodec == expected) } } } diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala index 5715893920ac6..95741c7b30289 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/orc/OrcFileFormat.scala @@ -74,7 +74,7 @@ class OrcFileFormat extends FileFormat with DataSourceRegister with Serializable val configuration = job.getConfiguration - configuration.set(COMPRESS.getAttribute, orcOptions.compressionCodecClassName) + configuration.set(COMPRESS.getAttribute, orcOptions.compressionCodec) configuration match { case conf: JobConf => conf.setOutputFormat(classOf[OrcOutputFormat]) From 8ae86ee11de1e2a4d22bf0e7111478167ce5a8b9 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Tue, 26 Dec 2017 16:09:11 +0800 Subject: [PATCH 08/55] Add test case with the table containing mixed compression codec --- .../sql/hive/CompressionCodecSuite.scala | 387 ++++++++---------- 1 file changed, 172 insertions(+), 215 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index baccf88e4b0b0..d7b33c151e5cf 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -21,40 +21,61 @@ import java.io.File import scala.collection.JavaConverters._ +import org.scalatest.BeforeAndAfterAll + import org.apache.hadoop.fs.Path import org.apache.orc.OrcConf.COMPRESS import org.apache.parquet.hadoop.ParquetOutputFormat -import org.apache.spark.sql.execution.datasources.parquet.ParquetTest +import org.apache.spark.sql.execution.datasources.orc.OrcOptions +import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetTest} import org.apache.spark.sql.hive.orc.OrcFileOperator import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.internal.SQLConf -class CompressionCodecSuite extends TestHiveSingleton with ParquetTest { +class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with BeforeAndAfterAll { import spark.implicits._ - private val maxRecordNum = 100000 + override def beforeAll(): Unit = { + super.beforeAll() + (0 until maxRecordNum).toDF("a").createOrReplaceTempView("table_source") + } - private def getConvertMetastoreConfName(format: String): String = format match { + override def afterAll(): Unit = { + try { + spark.catalog.dropTempView("table_source") + } finally { + super.afterAll() + } + } + + private val maxRecordNum = 500 + + private def getConvertMetastoreConfName(format: String): String = format.toLowerCase match { case "parquet" => HiveUtils.CONVERT_METASTORE_PARQUET.key case "orc" => HiveUtils.CONVERT_METASTORE_ORC.key } - private def getSparkCompressionConfName(format: String): String = format match { + private def getSparkCompressionConfName(format: String): String = format.toLowerCase match { case "parquet" => SQLConf.PARQUET_COMPRESSION.key case "orc" => SQLConf.ORC_COMPRESSION.key } - private def getHiveCompressPropName(format: String): String = { + private def getHiveCompressPropName(format: String): String = format.toLowerCase match { + case "parquet" => ParquetOutputFormat.COMPRESSION + case "orc" => COMPRESS.getAttribute + } + + private def normalizeCodecName(format: String, name: String): String = { format.toLowerCase match { - case "parquet" => ParquetOutputFormat.COMPRESSION - case "orc" => COMPRESS.getAttribute + case "parquet" => ParquetOptions.shortParquetCompressionCodecNames(name).name() + case "orc" => OrcOptions.shortOrcCompressionCodecNames(name) } } - private def getTableCompressionCodec(path: String, format: String): String = { + private def getTableCompressionCodec(path: String, format: String): Seq[String] = { val hadoopConf = spark.sessionState.newHadoopConf() - val codecs = format match { + val codecs = format.toLowerCase match { case "parquet" => for { footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) block <- footer.getParquetMetadata.getBlocks.asScala @@ -66,23 +87,20 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest { OrcFileOperator.getFileReader(orcFile.toPath.toString).get.getCompression.toString }.toSeq } - - assert(codecs.distinct.length == 1) - codecs.head + codecs.distinct } - private def writeDataToTable( - rootDir: File, - tableName: String, - isPartitioned: Boolean, - format: String, - compressionCodec: Option[String], - dataSourceName: String = "table_source"): Unit = { + private def createTable( + rootDir: File, + tableName: String, + isPartitioned: Boolean, + format: String, + compressionCodec: Option[String]): Unit = { val tblProperties = compressionCodec match { case Some(prop) => s"TBLPROPERTIES('${getHiveCompressPropName(format)}'='$prop')" case _ => "" } - val partitionCreate = if (isPartitioned) "PARTITIONED BY (p int)" else "" + val partitionCreate = if (isPartitioned) "PARTITIONED BY (p string)" else "" sql( s""" |CREATE TABLE $tableName(a int) @@ -91,13 +109,17 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest { |LOCATION '${rootDir.toURI.toString.stripSuffix("/")}/$tableName' |$tblProperties """.stripMargin) + } - val partitionInsert = if (isPartitioned) s"partition (p=10000)" else "" + private def writeDataToTable( + tableName: String, + partition: Option[String]): Unit = { + val partitionInsert = partition.map(p => s"partition ($p)").mkString sql( s""" - |INSERT OVERWRITE TABLE $tableName + |INSERT INTO TABLE $tableName |$partitionInsert - |SELECT * FROM $dataSourceName + |SELECT * FROM table_source """.stripMargin) } @@ -107,19 +129,19 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest { files.map(_.length()).sum } - private def getDataSizeByFormat(format: String, compressionCodec: Option[String], isPartitioned: Boolean): Long = { + private def getUncompressedDataSizeByFormat( + format: String, isPartitioned: Boolean): Long = { var totalSize = 0L - val tableName = s"tbl_$format${compressionCodec.mkString}" - withTempView("datasource_table") { - (0 until maxRecordNum).toDF("a").createOrReplaceTempView("datasource_table") - withSQLConf(getSparkCompressionConfName(format) -> compressionCodec.getOrElse("uncompressed")) { - withTempDir { tmpDir => - withTable(tableName) { - writeDataToTable(tmpDir, tableName, isPartitioned, format, compressionCodec, "datasource_table") - val partition = if (isPartitioned) "p=10000" else "" - val path =s"${tmpDir.getPath.stripSuffix("/")}/${tableName}/$partition" - totalSize = getTableSize(path) - } + val tableName = s"tbl_$format" + val codecName = normalizeCodecName(format, "uncompressed") + withSQLConf(getSparkCompressionConfName(format) -> codecName) { + withTempDir { tmpDir => + withTable(tableName) { + createTable(tmpDir, tableName, isPartitioned, format, Option(codecName)) + val partition = if (isPartitioned) Some("p='test'") else None + writeDataToTable(tableName, partition) + val path = s"${tmpDir.getPath.stripSuffix("/")}/$tableName/${partition.mkString}" + totalSize = getTableSize(path) } } } @@ -128,30 +150,32 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest { } private def checkCompressionCodecForTable( - format: String, - isPartitioned: Boolean, - compressionCodec: Option[String]) - (assertion: (String, Long) => Unit): Unit = { - val tableName = s"tbl_$format${isPartitioned}" + format: String, + isPartitioned: Boolean, + compressionCodec: Option[String]) + (assertion: (String, Long) => Unit): Unit = { + val tableName = s"tbl_$format$isPartitioned" withTempDir { tmpDir => withTable(tableName) { - writeDataToTable(tmpDir, tableName, isPartitioned, format, compressionCodec) - val partition = if (isPartitioned) "p=10000" else "" - val path = s"${tmpDir.getPath.stripSuffix("/")}/${tableName}/$partition" - val relCompressionCodec = getTableCompressionCodec(path, format) + createTable(tmpDir, tableName, isPartitioned, format, compressionCodec) + val partition = if (isPartitioned) Some("p='test'") else None + writeDataToTable(tableName, partition) + val path = s"${tmpDir.getPath.stripSuffix("/")}/$tableName/${partition.mkString}" + val relCompressionCodecs = getTableCompressionCodec(path, format) + assert(relCompressionCodecs.length == 1) val tableSize = getTableSize(path) - assertion(relCompressionCodec, tableSize) + assertion(relCompressionCodecs.head, tableSize) } } } private def checkTableCompressionCodecForCodecs( - format: String, - isPartitioned: Boolean, - convertMetastore: Boolean, - compressionCodecs: List[String], - tableCompressionCodecs: List[String]) - (assertionCompressionCodec: (Option[String], String, String, Long) => Unit): Unit = { + format: String, + isPartitioned: Boolean, + convertMetastore: Boolean, + compressionCodecs: List[String], + tableCompressionCodecs: List[String]) + (assertionCompressionCodec: (Option[String], String, String, Long) => Unit): Unit = { withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString) { tableCompressionCodecs.foreach { tableCompression => compressionCodecs.foreach { sessionCompressionCodec => @@ -168,193 +192,126 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest { } } - // To check if the compressionCodec takes effect, we check the data size with uncompressed size. - // and because partitioned table's schema may different with non-partitioned table's schema when - // convertMetastore is true, e.g. parquet, we save them independently. - private val partitionedParquetTableUncompressedSize = getDataSizeByFormat("parquet", None, true) - private val nonpartitionedParquetTableUncompressedSize = getDataSizeByFormat("parquet", None, false) - - // Orc data seems be consistent within partitioned table and non-partitioned table, - // but we just make the code look the same. - private val partitionedOrcTableUncompressedSize = getDataSizeByFormat("orc", None, true) - private val nonpartitionedOrcTableUncompressedSize = getDataSizeByFormat("orc", None, false) - - // When the amount of data is small, compressed data size may be smaller than uncompressed one, + // When the amount of data is small, compressed data size may be larger than uncompressed one, // so we just check the difference when compressionCodec is not NONE or UNCOMPRESSED. - // When convertMetastore is false, the uncompressed table size should be same as - // `partitionedParquetTableUncompressedSize`, regardless of whether there is a partition. - private def checkTableSize(format: String, compressionCodec: String, - ispartitioned: Boolean, convertMetastore: Boolean, tableSize: Long): Boolean = { + private def checkTableSize( + format: String, + compressionCodec: String, + isPartitioned: Boolean, + convertMetastore: Boolean, + tableSize: Long): Boolean = { format match { - case "parquet" => val uncompressedSize = - if(!convertMetastore || ispartitioned) partitionedParquetTableUncompressedSize - else nonpartitionedParquetTableUncompressedSize - - if(compressionCodec == "UNCOMPRESSED") tableSize == uncompressedSize - else tableSize != uncompressedSize - case "orc" => val uncompressedSize = - if(!convertMetastore || ispartitioned) partitionedOrcTableUncompressedSize - else nonpartitionedOrcTableUncompressedSize - - if(compressionCodec == "NONE") tableSize == uncompressedSize - else tableSize != uncompressedSize + case "parquet" => + val uncompressedSize = if (!convertMetastore || isPartitioned) { + getUncompressedDataSizeByFormat(format, isPartitioned = true) + } else { + getUncompressedDataSizeByFormat(format, isPartitioned = false) + } + + if (compressionCodec == "UNCOMPRESSED") { + tableSize == uncompressedSize + } else { + tableSize != uncompressedSize + } + case "orc" => + val uncompressedSize = if (!convertMetastore || isPartitioned) { + getUncompressedDataSizeByFormat(format, isPartitioned = true) + } else { + getUncompressedDataSizeByFormat(format, isPartitioned = false) + } + + if (compressionCodec == "NONE") { + tableSize == uncompressedSize + } else { + tableSize != uncompressedSize + } case _ => false } } - private def testCompressionCodec(testCondition: String)(f: => Unit): Unit = { - test("[SPARK-21786] - Check the priority between table-level compression and " + - s"session-level compression $testCondition") { - withTempView("table_source") { - (0 until maxRecordNum).toDF("a").createOrReplaceTempView("table_source") - f + def checkForTableWithCompressProp(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + checkTableCompressionCodecForCodecs( + format, + isPartitioned, + convertMetastore, + compressionCodecs = compressCodecs, + tableCompressionCodecs = compressCodecs) { + case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => + // For non-partitioned table and when convertMetastore is false, Expect session-level + // take effect, and in other cases expect table-level take effect + val expectCompressionCodec = + if (convertMetastore && !isPartitioned) sessionCompressionCodec + else tableCompressionCodec.get + + assert(expectCompressionCodec == realCompressionCodec) + assert(checkTableSize(format, expectCompressionCodec, + isPartitioned, convertMetastore, tableSize)) + } } } } - testCompressionCodec("when table-level and session-level compression are both configured and " + - "convertMetastore is false") { - def checkForTableWithCompressProp(format: String, compressCodecs: List[String]): Unit = { - // For tables with table-level compression property, when - // 'spark.sql.hive.convertMetastore[Parquet|Orc]' was set to 'false', partitioned tables - // and non-partitioned tables will always take the table-level compression - // configuration first and ignore session compression configuration. - // Check for partitioned table, when convertMetastore is false - checkTableCompressionCodecForCodecs( - format = format, - isPartitioned = true, - convertMetastore = false, - compressionCodecs = compressCodecs, - tableCompressionCodecs = compressCodecs) { - case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => - // Expect table-level take effect - assert(tableCompressionCodec.get == realCompressionCodec) - assert(checkTableSize(format, tableCompressionCodec.get, true, false, tableSize)) - } - - // Check for non-partitioned table, when convertMetastoreParquet is false - checkTableCompressionCodecForCodecs( - format = format, - isPartitioned = false, - convertMetastore = false, - compressionCodecs = compressCodecs, - tableCompressionCodecs = compressCodecs) { - case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => - // Expect table-level take effect - assert(tableCompressionCodec.get == realCompressionCodec) - assert(checkTableSize(format, tableCompressionCodec.get, false, false, tableSize)) + def checkForTableWithoutCompressProp(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + checkTableCompressionCodecForCodecs( + format, + isPartitioned, + convertMetastore, + compressionCodecs = compressCodecs, + tableCompressionCodecs = List(null)) { + case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => + // Always expect session-level take effect + assert(sessionCompressionCodec == realCompressionCodec) + assert(checkTableSize(format, sessionCompressionCodec, + isPartitioned, convertMetastore, tableSize)) + } } } + } + test("both table-level and session-level compression are set") { checkForTableWithCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) checkForTableWithCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) } - testCompressionCodec("when there's no table-level compression and convertMetastore is false") { - def checkForTableWithoutCompressProp(format: String, compressCodecs: List[String]): Unit = { - // For tables without table-level compression property, session-level compression - // configuration will take effect. - // Check for partitioned table, when convertMetastore is false - checkTableCompressionCodecForCodecs( - format = format, - isPartitioned = true, - convertMetastore = false, - compressionCodecs = compressCodecs, - tableCompressionCodecs = List(null)) { - case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => - // Expect session-level take effect - assert(sessionCompressionCodec == realCompressionCodec) - assert(checkTableSize(format, sessionCompressionCodec, true, false, tableSize)) - } - - // Check for non-partitioned table, when convertMetastore is false - checkTableCompressionCodecForCodecs( - format = format, - isPartitioned = false, - convertMetastore = false, - compressionCodecs = compressCodecs, - tableCompressionCodecs = List(null)) { - case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => - // Expect session-level take effect - assert(sessionCompressionCodec == realCompressionCodec) - assert(checkTableSize(format, sessionCompressionCodec, false, false, tableSize)) - } - } - + test("table-level compression is not set but session-level compressions is set ") { checkForTableWithoutCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) checkForTableWithoutCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) } - testCompressionCodec("when table-level and session-level compression are both configured and " + - "convertMetastore is true") { - def checkForTableWithCompressProp(format: String, compressCodecs: List[String]): Unit = { - // For tables with table-level compression property, when - // 'spark.sql.hive.convertMetastore[Parquet|Orc]' was set to 'true', partitioned tables - // will always take the table-level compression configuration first, but non-partitioned - // tables will take the session-level compression configuration. - // Check for partitioned table, when convertMetastore is true - checkTableCompressionCodecForCodecs( - format = format, - isPartitioned = true, - convertMetastore = true, - compressionCodecs = compressCodecs, - tableCompressionCodecs = compressCodecs) { - case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => - // Expect table-level take effect - assert(tableCompressionCodec.get == realCompressionCodec) - assert(checkTableSize(format, tableCompressionCodec.get, true, true, tableSize)) - } - - // Check for non-partitioned table, when convertMetastore is true - checkTableCompressionCodecForCodecs( - format = format, - isPartitioned = false, - convertMetastore = true, - compressionCodecs = compressCodecs, - tableCompressionCodecs = compressCodecs) { - case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => - // Expect session-level take effect - assert(sessionCompressionCodec == realCompressionCodec) - assert(checkTableSize(format, sessionCompressionCodec, false, true, tableSize)) + def checkTableWriteWithCompressionCodecs(format: String, compressCodecs: List[String]): Unit = { + Seq(true, false).foreach { isPartitioned => + Seq(true, false).foreach { convertMetastore => + withTempDir { tmpDir => + val tableName = s"tbl_$format$isPartitioned" + createTable(tmpDir, tableName, isPartitioned, format, None) + withTable(tableName) { + compressCodecs.foreach { compressionCodec => + val partition = if (isPartitioned) Some(s"p='$compressionCodec'") else None + withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString, + getSparkCompressionConfName(format) -> compressionCodec + ) { writeDataToTable(tableName, partition) } + } + val tablePath = s"${tmpDir.getPath.stripSuffix("/")}/$tableName" + val relCompressionCodecs = + if (isPartitioned) compressCodecs.flatMap { codec => + getTableCompressionCodec(s"$tablePath/p=$codec", format) + } else getTableCompressionCodec(tablePath, format) + + assert(relCompressionCodecs.distinct.sorted == compressCodecs.sorted) + val recordsNum = sql(s"SELECT * from $tableName").count() + assert(recordsNum == maxRecordNum * compressCodecs.length) + } + } } } - - checkForTableWithCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) - checkForTableWithCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) } - testCompressionCodec("when there's no table-level compression and convertMetastore is true") { - def checkForTableWithoutCompressProp(format: String, compressCodecs: List[String]): Unit = { - // For tables without table-level compression property, session-level compression - // configuration will take effect. - // Check for partitioned table, when convertMetastore is true - checkTableCompressionCodecForCodecs( - format = format, - isPartitioned = true, - convertMetastore = true, - compressionCodecs = compressCodecs, - tableCompressionCodecs = List(null)) { - case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => - // Expect session-level take effect - assert(sessionCompressionCodec == realCompressionCodec) - assert(checkTableSize(format, sessionCompressionCodec, true, true, tableSize)) - } - - // Check for non-partitioned table, when convertMetastore is true - checkTableCompressionCodecForCodecs( - format = format, - isPartitioned = false, - convertMetastore = true, - compressionCodecs = compressCodecs, - tableCompressionCodecs = List(null)) { - case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => - // Expect session-level take effect - assert(sessionCompressionCodec == realCompressionCodec) - assert(checkTableSize(format, sessionCompressionCodec, false, true, tableSize)) - } - } - - checkForTableWithoutCompressProp("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) - checkForTableWithoutCompressProp("orc", List("NONE", "SNAPPY", "ZLIB")) + test("test table containing mixed compression codec") { + checkTableWriteWithCompressionCodecs("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) + checkTableWriteWithCompressionCodecs("orc", List("NONE", "SNAPPY", "ZLIB")) } } \ No newline at end of file From 94ac716551c48ffd48b1f6590ceeac0a95888490 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Tue, 26 Dec 2017 16:26:46 +0800 Subject: [PATCH 09/55] Revert back --- .../spark/sql/execution/datasources/orc/OrcSourceSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index e2c09d3ac3cae..6f5f2fd795f74 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -156,8 +156,7 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll { Seq("NONE", "UNCOMPRESSED", "SNAPPY", "ZLIB", "LZO").foreach { c => withSQLConf(SQLConf.ORC_COMPRESSION.key -> c) { val expected = if (c == "UNCOMPRESSED") "NONE" else c - assert( - new OrcOptions(Map.empty[String, String], conf).compressionCodec == expected) + assert(new OrcOptions(Map.empty[String, String], conf).compressionCodec == expected) } } } From 43e041f2b7c350ddafea447a356b4df572ce0b2a Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Tue, 26 Dec 2017 16:30:12 +0800 Subject: [PATCH 10/55] Revert back --- .../scala/org/apache/spark/sql/hive/execution/HiveOptions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala index 78f712b88d203..bffb22d76870b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala @@ -116,7 +116,7 @@ object HiveOptions { val tableProps = tableInfo.getProperties.asScala.toMap tableInfo.getOutputFileFormatClassName.toLowerCase match { case formatName if formatName.endsWith("parquetoutputformat") => - val compressionCodec = new ParquetOptions(tableProps, sqlConf).compressionCodec + val compressionCodec = new ParquetOptions(tableProps, sqlConf).compressionCodecClassName Option((ParquetOutputFormat.COMPRESSION, compressionCodec)) case formatName if formatName.endsWith("orcoutputformat") => val compressionCodec = new OrcOptions(tableProps, sqlConf).compressionCodec From ee0c5587374ed33e29a763f8764263114a9b57ab Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Tue, 26 Dec 2017 16:33:51 +0800 Subject: [PATCH 11/55] Add a new line at the of file --- .../scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index d7b33c151e5cf..93cdd948bfa49 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -314,4 +314,4 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo checkTableWriteWithCompressionCodecs("parquet", List("UNCOMPRESSED", "SNAPPY", "GZIP")) checkTableWriteWithCompressionCodecs("orc", List("NONE", "SNAPPY", "ZLIB")) } -} \ No newline at end of file +} From e9f705d0ad783da5bd091632e98a6151d4d21cb6 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Tue, 2 Jan 2018 09:52:53 +0800 Subject: [PATCH 12/55] Fix scala style --- .../scala/org/apache/spark/sql/hive/execution/HiveOptions.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala index bffb22d76870b..192f6bbebc7d1 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala @@ -112,7 +112,7 @@ object HiveOptions { "mapkeyDelim" -> "mapkey.delim", "lineDelim" -> "line.delim").map { case (k, v) => k.toLowerCase(Locale.ROOT) -> v } - def getHiveWriteCompression(tableInfo: TableDesc, sqlConf: SQLConf): Option[(String, String)]= { + def getHiveWriteCompression(tableInfo: TableDesc, sqlConf: SQLConf): Option[(String, String)] = { val tableProps = tableInfo.getProperties.asScala.toMap tableInfo.getOutputFileFormatClassName.toLowerCase match { case formatName if formatName.endsWith("parquetoutputformat") => From d3aa7a01320b6af2866d7cf7c4f178eb23eae3ad Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Tue, 2 Jan 2018 10:20:41 +0800 Subject: [PATCH 13/55] Fix scala style --- .../org/apache/spark/sql/hive/CompressionCodecSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index 93cdd948bfa49..71aa95da17e3e 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -21,11 +21,10 @@ import java.io.File import scala.collection.JavaConverters._ -import org.scalatest.BeforeAndAfterAll - import org.apache.hadoop.fs.Path import org.apache.orc.OrcConf.COMPRESS import org.apache.parquet.hadoop.ParquetOutputFormat +import org.scalatest.BeforeAndAfterAll import org.apache.spark.sql.execution.datasources.orc.OrcOptions import org.apache.spark.sql.execution.datasources.parquet.{ParquetOptions, ParquetTest} From 5244aafc2d7945c11c96398b8d5b752b45fd148c Mon Sep 17 00:00:00 2001 From: Xianjin YE Date: Tue, 2 Jan 2018 23:30:38 +0800 Subject: [PATCH 14/55] [SPARK-22897][CORE] Expose stageAttemptId in TaskContext ## What changes were proposed in this pull request? stageAttemptId added in TaskContext and corresponding construction modification ## How was this patch tested? Added a new test in TaskContextSuite, two cases are tested: 1. Normal case without failure 2. Exception case with resubmitted stages Link to [SPARK-22897](https://issues.apache.org/jira/browse/SPARK-22897) Author: Xianjin YE Closes #20082 from advancedxy/SPARK-22897. (cherry picked from commit a6fc300e91273230e7134ac6db95ccb4436c6f8f) Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/TaskContext.scala | 9 +++++- .../org/apache/spark/TaskContextImpl.scala | 5 ++-- .../org/apache/spark/scheduler/Task.scala | 1 + .../spark/JavaTaskContextCompileCheck.java | 2 ++ .../scala/org/apache/spark/ShuffleSuite.scala | 6 ++-- .../spark/memory/MemoryTestingUtils.scala | 1 + .../spark/scheduler/TaskContextSuite.scala | 29 +++++++++++++++++-- .../spark/storage/BlockInfoManagerSuite.scala | 2 +- project/MimaExcludes.scala | 3 ++ .../UnsafeFixedWidthAggregationMapSuite.scala | 1 + .../UnsafeKVExternalSorterSuite.scala | 1 + .../execution/UnsafeRowSerializerSuite.scala | 2 +- .../SortBasedAggregationStoreSuite.scala | 3 +- 13 files changed, 54 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/TaskContext.scala b/core/src/main/scala/org/apache/spark/TaskContext.scala index 0b87cd503d4fa..69739745aa6cf 100644 --- a/core/src/main/scala/org/apache/spark/TaskContext.scala +++ b/core/src/main/scala/org/apache/spark/TaskContext.scala @@ -66,7 +66,7 @@ object TaskContext { * An empty task context that does not represent an actual task. This is only used in tests. */ private[spark] def empty(): TaskContextImpl = { - new TaskContextImpl(0, 0, 0, 0, null, new Properties, null) + new TaskContextImpl(0, 0, 0, 0, 0, null, new Properties, null) } } @@ -150,6 +150,13 @@ abstract class TaskContext extends Serializable { */ def stageId(): Int + /** + * How many times the stage that this task belongs to has been attempted. The first stage attempt + * will be assigned stageAttemptNumber = 0, and subsequent attempts will have increasing attempt + * numbers. + */ + def stageAttemptNumber(): Int + /** * The ID of the RDD partition that is computed by this task. */ diff --git a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala index 01d8973e1bb06..cccd3ea457ba4 100644 --- a/core/src/main/scala/org/apache/spark/TaskContextImpl.scala +++ b/core/src/main/scala/org/apache/spark/TaskContextImpl.scala @@ -41,8 +41,9 @@ import org.apache.spark.util._ * `TaskMetrics` & `MetricsSystem` objects are not thread safe. */ private[spark] class TaskContextImpl( - val stageId: Int, - val partitionId: Int, + override val stageId: Int, + override val stageAttemptNumber: Int, + override val partitionId: Int, override val taskAttemptId: Long, override val attemptNumber: Int, override val taskMemoryManager: TaskMemoryManager, diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala index 7767ef1803a06..f536fc2a5f0a1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala @@ -79,6 +79,7 @@ private[spark] abstract class Task[T]( SparkEnv.get.blockManager.registerTask(taskAttemptId) context = new TaskContextImpl( stageId, + stageAttemptId, // stageAttemptId and stageAttemptNumber are semantically equal partitionId, taskAttemptId, attemptNumber, diff --git a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java index 94f5805853e1e..f8e233a05a447 100644 --- a/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java +++ b/core/src/test/java/test/org/apache/spark/JavaTaskContextCompileCheck.java @@ -38,6 +38,7 @@ public static void test() { tc.attemptNumber(); tc.partitionId(); tc.stageId(); + tc.stageAttemptNumber(); tc.taskAttemptId(); } @@ -51,6 +52,7 @@ public void onTaskCompletion(TaskContext context) { context.isCompleted(); context.isInterrupted(); context.stageId(); + context.stageAttemptNumber(); context.partitionId(); context.addTaskCompletionListener(this); } diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala index 3931d53b4ae0a..ced5a06516f75 100644 --- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala +++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala @@ -363,14 +363,14 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC // first attempt -- its successful val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(0, 0, 0, 0L, 0, taskMemoryManager, new Properties, metricsSystem)) val data1 = (1 to 10).map { x => x -> x} // second attempt -- also successful. We'll write out different data, // just to simulate the fact that the records may get written differently // depending on what gets spilled, what gets combined, etc. val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0, - new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(0, 0, 0, 1L, 0, taskMemoryManager, new Properties, metricsSystem)) val data2 = (11 to 20).map { x => x -> x} // interleave writes of both attempts -- we want to test that both attempts can occur @@ -398,7 +398,7 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC } val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1, - new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) + new TaskContextImpl(1, 0, 0, 2L, 0, taskMemoryManager, new Properties, metricsSystem)) val readData = reader.read().toIndexedSeq assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq) diff --git a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala index 362cd861cc248..dcf89e4f75acf 100644 --- a/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala +++ b/core/src/test/scala/org/apache/spark/memory/MemoryTestingUtils.scala @@ -29,6 +29,7 @@ object MemoryTestingUtils { val taskMemoryManager = new TaskMemoryManager(env.memoryManager, 0) new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = 0, attemptNumber = 0, diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala index a1d9085fa085d..aa9c36c0aaacb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskContextSuite.scala @@ -29,6 +29,7 @@ import org.apache.spark.memory.TaskMemoryManager import org.apache.spark.metrics.source.JvmSource import org.apache.spark.network.util.JavaUtils import org.apache.spark.rdd.RDD +import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.util._ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSparkContext { @@ -158,6 +159,30 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark assert(attemptIdsWithFailedTask.toSet === Set(0, 1)) } + test("TaskContext.stageAttemptNumber getter") { + sc = new SparkContext("local[1,2]", "test") + + // Check stageAttemptNumbers are 0 for initial stage + val stageAttemptNumbers = sc.parallelize(Seq(1, 2), 2).mapPartitions { _ => + Seq(TaskContext.get().stageAttemptNumber()).iterator + }.collect() + assert(stageAttemptNumbers.toSet === Set(0)) + + // Check stageAttemptNumbers that are resubmitted when tasks have FetchFailedException + val stageAttemptNumbersWithFailedStage = + sc.parallelize(Seq(1, 2, 3, 4), 4).repartition(1).mapPartitions { _ => + val stageAttemptNumber = TaskContext.get().stageAttemptNumber() + if (stageAttemptNumber < 2) { + // Throw FetchFailedException to explicitly trigger stage resubmission. A normal exception + // will only trigger task resubmission in the same stage. + throw new FetchFailedException(null, 0, 0, 0, "Fake") + } + Seq(stageAttemptNumber).iterator + }.collect() + + assert(stageAttemptNumbersWithFailedStage.toSet === Set(2)) + } + test("accumulators are updated on exception failures") { // This means use 1 core and 4 max task failures sc = new SparkContext("local[1,4]", "test") @@ -190,7 +215,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // accumulator updates from it. val taskMetrics = TaskMetrics.empty val task = new Task[Int](0, 0, 0) { - context = new TaskContextImpl(0, 0, 0L, 0, + context = new TaskContextImpl(0, 0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, @@ -213,7 +238,7 @@ class TaskContextSuite extends SparkFunSuite with BeforeAndAfter with LocalSpark // accumulator updates from it. val taskMetrics = TaskMetrics.registered val task = new Task[Int](0, 0, 0) { - context = new TaskContextImpl(0, 0, 0L, 0, + context = new TaskContextImpl(0, 0, 0, 0L, 0, new TaskMemoryManager(SparkEnv.get.memoryManager, 0L), new Properties, SparkEnv.get.metricsSystem, diff --git a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala index 917db766f7f11..9c0699bc981f8 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockInfoManagerSuite.scala @@ -62,7 +62,7 @@ class BlockInfoManagerSuite extends SparkFunSuite with BeforeAndAfterEach { private def withTaskId[T](taskAttemptId: Long)(block: => T): T = { try { TaskContext.setTaskContext( - new TaskContextImpl(0, 0, taskAttemptId, 0, null, new Properties, null)) + new TaskContextImpl(0, 0, 0, taskAttemptId, 0, null, new Properties, null)) block } finally { TaskContext.unset() diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 81584af6813ea..3b452f35c5ec1 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -36,6 +36,9 @@ object MimaExcludes { // Exclude rules for 2.3.x lazy val v23excludes = v22excludes ++ Seq( + // [SPARK-22897] Expose stageAttemptId in TaskContext + ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.TaskContext.stageAttemptNumber"), + // SPARK-22789: Map-only continuous processing execution ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$8"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.streaming.StreamingQueryManager.startQuery$default$6"), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala index 232c1beae7998..3e31d22e15c0e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala @@ -70,6 +70,7 @@ class UnsafeFixedWidthAggregationMapSuite TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = Random.nextInt(10000), attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala index 604502f2a57d0..6af9f8b77f8d3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala @@ -116,6 +116,7 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext { val taskMemMgr = new TaskMemoryManager(memoryManager, 0) TaskContext.setTaskContext(new TaskContextImpl( stageId = 0, + stageAttemptNumber = 0, partitionId = 0, taskAttemptId = 98456, attemptNumber = 0, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala index dff88ce7f1b9a..a3ae93810aa3c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeRowSerializerSuite.scala @@ -114,7 +114,7 @@ class UnsafeRowSerializerSuite extends SparkFunSuite with LocalSparkContext { (i, converter(Row(i))) } val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0) - val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, new Properties, null) + val taskContext = new TaskContextImpl(0, 0, 0, 0, 0, taskMemoryManager, new Properties, null) val sorter = new ExternalSorter[Int, UnsafeRow, UnsafeRow]( taskContext, diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala index 10f1ee279bedf..3fad7dfddadcc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationStoreSuite.scala @@ -35,7 +35,8 @@ class SortBasedAggregationStoreSuite extends SparkFunSuite with LocalSparkConte val conf = new SparkConf() sc = new SparkContext("local[2, 4]", "test", conf) val taskManager = new TaskMemoryManager(new TestMemoryManager(conf), 0) - TaskContext.setTaskContext(new TaskContextImpl(0, 0, 0, 0, taskManager, new Properties, null)) + TaskContext.setTaskContext( + new TaskContextImpl(0, 0, 0, 0, 0, taskManager, new Properties, null)) } override def afterAll(): Unit = TaskContext.unset() From b96a2132413937c013e1099be3ec4bc420c947fd Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Wed, 3 Jan 2018 21:40:51 +0800 Subject: [PATCH 15/55] [SPARK-22938] Assert that SQLConf.get is accessed only on the driver. ## What changes were proposed in this pull request? Assert if code tries to access SQLConf.get on executor. This can lead to hard to detect bugs, where the executor will read fallbackConf, falling back to default config values, ignoring potentially changed non-default configs. If a config is to be passed to executor code, it needs to be read on the driver, and passed explicitly. ## How was this patch tested? Check in existing tests. Author: Juliusz Sompolski Closes #20136 from juliuszsompolski/SPARK-22938. (cherry picked from commit 247a08939d58405aef39b2a4e7773aa45474ad12) Signed-off-by: Wenchen Fan --- .../scala/org/apache/spark/sql/internal/SQLConf.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 4f77c54a7af57..80cdc61484c0f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -27,11 +27,13 @@ import scala.util.matching.Regex import org.apache.hadoop.fs.Path +import org.apache.spark.{SparkContext, SparkEnv} import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.network.util.ByteUnit import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.expressions.codegen.CodeGenerator +import org.apache.spark.util.Utils //////////////////////////////////////////////////////////////////////////////////////////////////// // This file defines the configuration options for Spark SQL. @@ -70,7 +72,7 @@ object SQLConf { * Default config. Only used when there is no active SparkSession for the thread. * See [[get]] for more information. */ - private val fallbackConf = new ThreadLocal[SQLConf] { + private lazy val fallbackConf = new ThreadLocal[SQLConf] { override def initialValue: SQLConf = new SQLConf } @@ -1087,6 +1089,12 @@ object SQLConf { class SQLConf extends Serializable with Logging { import SQLConf._ + if (Utils.isTesting && SparkEnv.get != null) { + // assert that we're only accessing it on the driver. + assert(SparkEnv.get.executorId == SparkContext.DRIVER_IDENTIFIER, + "SQLConf should only be created and accessed on the driver.") + } + /** Only low degree of contention is expected for conf, thus NOT using ConcurrentHashMap. */ @transient protected[spark] val settings = java.util.Collections.synchronizedMap( new java.util.HashMap[String, String]()) From a05e85ecb76091567a26a3a14ad0879b4728addc Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Wed, 3 Jan 2018 22:09:30 +0800 Subject: [PATCH 16/55] [SPARK-22934][SQL] Make optional clauses order insensitive for CREATE TABLE SQL statement ## What changes were proposed in this pull request? Currently, our CREATE TABLE syntax require the EXACT order of clauses. It is pretty hard to remember the exact order. Thus, this PR is to make optional clauses order insensitive for `CREATE TABLE` SQL statement. ``` CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name [(col_name1 col_type1 [COMMENT col_comment1], ...)] USING datasource [OPTIONS (key1=val1, key2=val2, ...)] [PARTITIONED BY (col_name1, col_name2, ...)] [CLUSTERED BY (col_name3, col_name4, ...) INTO num_buckets BUCKETS] [LOCATION path] [COMMENT table_comment] [TBLPROPERTIES (key1=val1, key2=val2, ...)] [AS select_statement] ``` The proposal is to make the following clauses order insensitive. ``` [OPTIONS (key1=val1, key2=val2, ...)] [PARTITIONED BY (col_name1, col_name2, ...)] [CLUSTERED BY (col_name3, col_name4, ...) INTO num_buckets BUCKETS] [LOCATION path] [COMMENT table_comment] [TBLPROPERTIES (key1=val1, key2=val2, ...)] ``` The same idea is also applicable to Create Hive Table. ``` CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name [(col_name1[:] col_type1 [COMMENT col_comment1], ...)] [COMMENT table_comment] [PARTITIONED BY (col_name2[:] col_type2 [COMMENT col_comment2], ...)] [ROW FORMAT row_format] [STORED AS file_format] [LOCATION path] [TBLPROPERTIES (key1=val1, key2=val2, ...)] [AS select_statement] ``` The proposal is to make the following clauses order insensitive. ``` [COMMENT table_comment] [PARTITIONED BY (col_name2[:] col_type2 [COMMENT col_comment2], ...)] [ROW FORMAT row_format] [STORED AS file_format] [LOCATION path] [TBLPROPERTIES (key1=val1, key2=val2, ...)] ``` ## How was this patch tested? Added test cases Author: gatorsmile Closes #20133 from gatorsmile/createDataSourceTableDDL. (cherry picked from commit 1a87a1609c4d2c9027a2cf669ea3337b89f61fb6) Signed-off-by: gatorsmile --- .../spark/sql/catalyst/parser/SqlBase.g4 | 24 +- .../sql/catalyst/parser/ParserUtils.scala | 9 + .../spark/sql/execution/SparkSqlParser.scala | 81 +++++-- .../execution/command/DDLParserSuite.scala | 220 ++++++++++++++---- .../sql/execution/command/DDLSuite.scala | 2 +- .../sql/hive/execution/HiveDDLSuite.scala | 13 +- .../sql/hive/execution/SQLQuerySuite.scala | 124 +++++----- 7 files changed, 335 insertions(+), 138 deletions(-) diff --git a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 index 6fe995f650d55..6daf01d98426c 100644 --- a/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 +++ b/sql/catalyst/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBase.g4 @@ -73,18 +73,22 @@ statement | ALTER DATABASE identifier SET DBPROPERTIES tablePropertyList #setDatabaseProperties | DROP DATABASE (IF EXISTS)? identifier (RESTRICT | CASCADE)? #dropDatabase | createTableHeader ('(' colTypeList ')')? tableProvider - (OPTIONS options=tablePropertyList)? - (PARTITIONED BY partitionColumnNames=identifierList)? - bucketSpec? locationSpec? - (COMMENT comment=STRING)? - (TBLPROPERTIES tableProps=tablePropertyList)? + ((OPTIONS options=tablePropertyList) | + (PARTITIONED BY partitionColumnNames=identifierList) | + bucketSpec | + locationSpec | + (COMMENT comment=STRING) | + (TBLPROPERTIES tableProps=tablePropertyList))* (AS? query)? #createTable | createTableHeader ('(' columns=colTypeList ')')? - (COMMENT comment=STRING)? - (PARTITIONED BY '(' partitionColumns=colTypeList ')')? - bucketSpec? skewSpec? - rowFormat? createFileFormat? locationSpec? - (TBLPROPERTIES tablePropertyList)? + ((COMMENT comment=STRING) | + (PARTITIONED BY '(' partitionColumns=colTypeList ')') | + bucketSpec | + skewSpec | + rowFormat | + createFileFormat | + locationSpec | + (TBLPROPERTIES tableProps=tablePropertyList))* (AS? query)? #createHiveTable | CREATE TABLE (IF NOT EXISTS)? target=tableIdentifier LIKE source=tableIdentifier locationSpec? #createTableLike diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala index 9b127f91648e6..89347f4b1f7bf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/ParserUtils.scala @@ -16,6 +16,8 @@ */ package org.apache.spark.sql.catalyst.parser +import java.util + import scala.collection.mutable.StringBuilder import org.antlr.v4.runtime.{ParserRuleContext, Token} @@ -39,6 +41,13 @@ object ParserUtils { throw new ParseException(s"Operation not allowed: $message", ctx) } + def checkDuplicateClauses[T]( + nodes: util.List[T], clauseName: String, ctx: ParserRuleContext): Unit = { + if (nodes.size() > 1) { + throw new ParseException(s"Found duplicate clauses: $clauseName", ctx) + } + } + /** Check if duplicate keys exist in a set of key-value pairs. */ def checkDuplicateKeys[T](keyPairs: Seq[(String, T)], ctx: ParserRuleContext): Unit = { keyPairs.groupBy(_._1).filter(_._2.size > 1).foreach { case (key, _) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index 29b584b55972c..d3cfd2a1ffbf2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -383,16 +383,19 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * CREATE [TEMPORARY] TABLE [IF NOT EXISTS] [db_name.]table_name * USING table_provider - * [OPTIONS table_property_list] - * [PARTITIONED BY (col_name, col_name, ...)] - * [CLUSTERED BY (col_name, col_name, ...) - * [SORTED BY (col_name [ASC|DESC], ...)] - * INTO num_buckets BUCKETS - * ] - * [LOCATION path] - * [COMMENT table_comment] - * [TBLPROPERTIES (property_name=property_value, ...)] + * create_table_clauses * [[AS] select_statement]; + * + * create_table_clauses (order insensitive): + * [OPTIONS table_property_list] + * [PARTITIONED BY (col_name, col_name, ...)] + * [CLUSTERED BY (col_name, col_name, ...) + * [SORTED BY (col_name [ASC|DESC], ...)] + * INTO num_buckets BUCKETS + * ] + * [LOCATION path] + * [COMMENT table_comment] + * [TBLPROPERTIES (property_name=property_value, ...)] * }}} */ override def visitCreateTable(ctx: CreateTableContext): LogicalPlan = withOrigin(ctx) { @@ -400,6 +403,14 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { if (external) { operationNotAllowed("CREATE EXTERNAL TABLE ... USING", ctx) } + + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.OPTIONS, "OPTIONS", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + val options = Option(ctx.options).map(visitPropertyKeyValues).getOrElse(Map.empty) val provider = ctx.tableProvider.qualifiedName.getText val schema = Option(ctx.colTypeList()).map(createSchema) @@ -408,9 +419,9 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { .map(visitIdentifierList(_).toArray) .getOrElse(Array.empty[String]) val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) - val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) - val location = Option(ctx.locationSpec).map(visitLocationSpec) + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) val storage = DataSource.buildStorageFormatFromOptions(options) if (location.isDefined && storage.locationUri.isDefined) { @@ -1087,13 +1098,16 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { * {{{ * CREATE [EXTERNAL] TABLE [IF NOT EXISTS] [db_name.]table_name * [(col1[:] data_type [COMMENT col_comment], ...)] - * [COMMENT table_comment] - * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] - * [ROW FORMAT row_format] - * [STORED AS file_format] - * [LOCATION path] - * [TBLPROPERTIES (property_name=property_value, ...)] + * create_table_clauses * [AS select_statement]; + * + * create_table_clauses (order insensitive): + * [COMMENT table_comment] + * [PARTITIONED BY (col2[:] data_type [COMMENT col_comment], ...)] + * [ROW FORMAT row_format] + * [STORED AS file_format] + * [LOCATION path] + * [TBLPROPERTIES (property_name=property_value, ...)] * }}} */ override def visitCreateHiveTable(ctx: CreateHiveTableContext): LogicalPlan = withOrigin(ctx) { @@ -1104,15 +1118,23 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { "CREATE TEMPORARY TABLE is not supported yet. " + "Please use CREATE TEMPORARY VIEW as an alternative.", ctx) } - if (ctx.skewSpec != null) { + if (ctx.skewSpec.size > 0) { operationNotAllowed("CREATE TABLE ... SKEWED BY", ctx) } + checkDuplicateClauses(ctx.TBLPROPERTIES, "TBLPROPERTIES", ctx) + checkDuplicateClauses(ctx.PARTITIONED, "PARTITIONED BY", ctx) + checkDuplicateClauses(ctx.COMMENT, "COMMENT", ctx) + checkDuplicateClauses(ctx.bucketSpec(), "CLUSTERED BY", ctx) + checkDuplicateClauses(ctx.createFileFormat, "STORED AS/BY", ctx) + checkDuplicateClauses(ctx.rowFormat, "ROW FORMAT", ctx) + checkDuplicateClauses(ctx.locationSpec, "LOCATION", ctx) + val dataCols = Option(ctx.columns).map(visitColTypeList).getOrElse(Nil) val partitionCols = Option(ctx.partitionColumns).map(visitColTypeList).getOrElse(Nil) - val properties = Option(ctx.tablePropertyList).map(visitPropertyKeyValues).getOrElse(Map.empty) + val properties = Option(ctx.tableProps).map(visitPropertyKeyValues).getOrElse(Map.empty) val selectQuery = Option(ctx.query).map(plan) - val bucketSpec = Option(ctx.bucketSpec()).map(visitBucketSpec) + val bucketSpec = ctx.bucketSpec().asScala.headOption.map(visitBucketSpec) // Note: Hive requires partition columns to be distinct from the schema, so we need // to include the partition columns here explicitly @@ -1120,12 +1142,12 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { // Storage format val defaultStorage = HiveSerDe.getDefaultStorage(conf) - validateRowFormatFileFormat(ctx.rowFormat, ctx.createFileFormat, ctx) - val fileStorage = Option(ctx.createFileFormat).map(visitCreateFileFormat) + validateRowFormatFileFormat(ctx.rowFormat.asScala, ctx.createFileFormat.asScala, ctx) + val fileStorage = ctx.createFileFormat.asScala.headOption.map(visitCreateFileFormat) .getOrElse(CatalogStorageFormat.empty) - val rowStorage = Option(ctx.rowFormat).map(visitRowFormat) + val rowStorage = ctx.rowFormat.asScala.headOption.map(visitRowFormat) .getOrElse(CatalogStorageFormat.empty) - val location = Option(ctx.locationSpec).map(visitLocationSpec) + val location = ctx.locationSpec.asScala.headOption.map(visitLocationSpec) // If we are creating an EXTERNAL table, then the LOCATION field is required if (external && location.isEmpty) { operationNotAllowed("CREATE EXTERNAL TABLE must be accompanied by LOCATION", ctx) @@ -1180,7 +1202,7 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { ctx) } - val hasStorageProperties = (ctx.createFileFormat != null) || (ctx.rowFormat != null) + val hasStorageProperties = (ctx.createFileFormat.size != 0) || (ctx.rowFormat.size != 0) if (conf.convertCTAS && !hasStorageProperties) { // At here, both rowStorage.serdeProperties and fileStorage.serdeProperties // are empty Maps. @@ -1366,6 +1388,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder(conf) { } } + private def validateRowFormatFileFormat( + rowFormatCtx: Seq[RowFormatContext], + createFileFormatCtx: Seq[CreateFileFormatContext], + parentCtx: ParserRuleContext): Unit = { + if (rowFormatCtx.size == 1 && createFileFormatCtx.size == 1) { + validateRowFormatFileFormat(rowFormatCtx.head, createFileFormatCtx.head, parentCtx) + } + } + /** * Create or replace a view. This creates a [[CreateViewCommand]] command. * diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala index eb7c33590b602..2b1aea08b1223 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLParserSuite.scala @@ -54,6 +54,13 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + private def intercept(sqlCommand: String, messages: String*): Unit = { + val e = intercept[ParseException](parser.parsePlan(sqlCommand)).getMessage + messages.foreach { message => + assert(e.contains(message)) + } + } + private def parseAs[T: ClassTag](query: String): T = { parser.parsePlan(query) match { case t: T => t @@ -494,6 +501,37 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + test("Duplicate clauses - create table") { + def createTableHeader(duplicateClause: String, isNative: Boolean): String = { + val fileFormat = if (isNative) "USING parquet" else "STORED AS parquet" + s"CREATE TABLE my_tab(a INT, b STRING) $fileFormat $duplicateClause $duplicateClause" + } + + Seq(true, false).foreach { isNative => + intercept(createTableHeader("TBLPROPERTIES('test' = 'test2')", isNative), + "Found duplicate clauses: TBLPROPERTIES") + intercept(createTableHeader("LOCATION '/tmp/file'", isNative), + "Found duplicate clauses: LOCATION") + intercept(createTableHeader("COMMENT 'a table'", isNative), + "Found duplicate clauses: COMMENT") + intercept(createTableHeader("CLUSTERED BY(b) INTO 256 BUCKETS", isNative), + "Found duplicate clauses: CLUSTERED BY") + } + + // Only for native data source tables + intercept(createTableHeader("PARTITIONED BY (b)", isNative = true), + "Found duplicate clauses: PARTITIONED BY") + + // Only for Hive serde tables + intercept(createTableHeader("PARTITIONED BY (k int)", isNative = false), + "Found duplicate clauses: PARTITIONED BY") + intercept(createTableHeader("STORED AS parquet", isNative = false), + "Found duplicate clauses: STORED AS/BY") + intercept( + createTableHeader("ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe'", isNative = false), + "Found duplicate clauses: ROW FORMAT") + } + test("create table - with location") { val v1 = "CREATE TABLE my_tab(a INT, b STRING) USING parquet LOCATION '/tmp/file'" @@ -1153,38 +1191,119 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { } } + test("Test CTAS against data source tables") { + val s1 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s2 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |LOCATION '/user/external/page_view' + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE TABLE IF NOT EXISTS mydb.page_view + |USING parquet + |COMMENT 'This is the staging page view table' + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.provider == Some("parquet")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } + } + test("Test CTAS #1") { val s1 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view |COMMENT 'This is the staging page view table' |STORED AS RCFILE |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin + |AS SELECT * FROM src + """.stripMargin - val (desc, exists) = extractTableDesc(s1) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - assert(desc.comment == Some("This is the staging page view table")) - // TODO will be SQLText - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) - assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) - assert(desc.storage.serde == - Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |AS SELECT * FROM src + """.stripMargin + + val s3 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |LOCATION '/user/external/page_view' + |STORED AS RCFILE + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + checkParsing(s3) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + assert(desc.comment == Some("This is the staging page view table")) + // TODO will be SQLText + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.inputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileInputFormat")) + assert(desc.storage.outputFormat == Some("org.apache.hadoop.hive.ql.io.RCFileOutputFormat")) + assert(desc.storage.serde == + Some("org.apache.hadoop.hive.serde2.columnar.LazyBinaryColumnarSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } } test("Test CTAS #2") { - val s2 = - """CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + val s1 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view |COMMENT 'This is the staging page view table' |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' | STORED AS @@ -1192,26 +1311,45 @@ class DDLParserSuite extends PlanTest with SharedSQLContext { | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' |LOCATION '/user/external/page_view' |TBLPROPERTIES ('p1'='v1', 'p2'='v2') - |AS SELECT * FROM src""".stripMargin + |AS SELECT * FROM src + """.stripMargin - val (desc, exists) = extractTableDesc(s2) - assert(exists) - assert(desc.identifier.database == Some("mydb")) - assert(desc.identifier.table == "page_view") - assert(desc.tableType == CatalogTableType.EXTERNAL) - assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) - assert(desc.schema.isEmpty) // will be populated later when the table is actually created - // TODO will be SQLText - assert(desc.comment == Some("This is the staging page view table")) - assert(desc.viewText.isEmpty) - assert(desc.viewDefaultDatabase.isEmpty) - assert(desc.viewQueryColumnNames.isEmpty) - assert(desc.partitionColumnNames.isEmpty) - assert(desc.storage.properties == Map()) - assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) - assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) - assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) - assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + val s2 = + """ + |CREATE EXTERNAL TABLE IF NOT EXISTS mydb.page_view + |LOCATION '/user/external/page_view' + |TBLPROPERTIES ('p1'='v1', 'p2'='v2') + |ROW FORMAT SERDE 'parquet.hive.serde.ParquetHiveSerDe' + | STORED AS + | INPUTFORMAT 'parquet.hive.DeprecatedParquetInputFormat' + | OUTPUTFORMAT 'parquet.hive.DeprecatedParquetOutputFormat' + |COMMENT 'This is the staging page view table' + |AS SELECT * FROM src + """.stripMargin + + checkParsing(s1) + checkParsing(s2) + + def checkParsing(sql: String): Unit = { + val (desc, exists) = extractTableDesc(sql) + assert(exists) + assert(desc.identifier.database == Some("mydb")) + assert(desc.identifier.table == "page_view") + assert(desc.tableType == CatalogTableType.EXTERNAL) + assert(desc.storage.locationUri == Some(new URI("/user/external/page_view"))) + assert(desc.schema.isEmpty) // will be populated later when the table is actually created + // TODO will be SQLText + assert(desc.comment == Some("This is the staging page view table")) + assert(desc.viewText.isEmpty) + assert(desc.viewDefaultDatabase.isEmpty) + assert(desc.viewQueryColumnNames.isEmpty) + assert(desc.partitionColumnNames.isEmpty) + assert(desc.storage.properties == Map()) + assert(desc.storage.inputFormat == Some("parquet.hive.DeprecatedParquetInputFormat")) + assert(desc.storage.outputFormat == Some("parquet.hive.DeprecatedParquetOutputFormat")) + assert(desc.storage.serde == Some("parquet.hive.serde.ParquetHiveSerDe")) + assert(desc.properties == Map("p1" -> "v1", "p2" -> "v2")) + } } test("Test CTAS #3") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index fdb9b2f51f9cb..591510c1d8283 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1971,8 +1971,8 @@ abstract class DDLSuite extends QueryTest with SQLTestUtils { s""" |CREATE TABLE t(a int, b int, c int, d int) |USING parquet - |PARTITIONED BY(a, b) |LOCATION "${dir.toURI}" + |PARTITIONED BY(a, b) """.stripMargin) spark.sql("INSERT INTO TABLE t PARTITION(a=1, b=2) SELECT 3, 4") checkAnswer(spark.table("t"), Row(3, 4, 1, 2) :: Nil) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index f2e0c695ca38b..65be244418670 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -875,12 +875,13 @@ class HiveDDLSuite test("desc table for Hive table - bucketed + sorted table") { withTable("tbl") { - sql(s""" - CREATE TABLE tbl (id int, name string) - PARTITIONED BY (ds string) - CLUSTERED BY(id) - SORTED BY(id, name) INTO 1024 BUCKETS - """) + sql( + s""" + |CREATE TABLE tbl (id int, name string) + |CLUSTERED BY(id) + |SORTED BY(id, name) INTO 1024 BUCKETS + |PARTITIONED BY (ds string) + """.stripMargin) val x = sql("DESC FORMATTED tbl").collect() assert(x.containsSlice( diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index 07ae3ae945848..47adc77a52d51 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -461,51 +461,55 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { } test("CTAS without serde without location") { - val originalConf = sessionState.conf.convertCTAS - - setConf(SQLConf.CONVERT_CTAS, true) - - val defaultDataSource = sessionState.conf.defaultDataSourceName - try { - sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - val message = intercept[AnalysisException] { + withSQLConf(SQLConf.CONVERT_CTAS.key -> "true") { + val defaultDataSource = sessionState.conf.defaultDataSourceName + withTable("ctas1") { sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - }.getMessage - assert(message.contains("already exists")) - checkRelation("ctas1", true, defaultDataSource) - sql("DROP TABLE ctas1") + sql("CREATE TABLE IF NOT EXISTS ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + val message = intercept[AnalysisException] { + sql("CREATE TABLE ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + }.getMessage + assert(message.contains("already exists")) + checkRelation("ctas1", isDataSourceTable = true, defaultDataSource) + } // Specifying database name for query can be converted to data source write path // is not allowed right now. - sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE default.ctas1 AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = true, defaultDataSource) + } - sql("CREATE TABLE ctas1 stored as textfile" + + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as textfile" + " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "text") - sql("DROP TABLE ctas1") + checkRelation("ctas1", isDataSourceTable = false, "text") + } - sql("CREATE TABLE ctas1 stored as sequencefile" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "sequence") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as sequencefile" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "sequence") + } - sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "rcfile") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as rcfile AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "rcfile") + } - sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "orc") - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql("CREATE TABLE ctas1 stored as orc AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation("ctas1", isDataSourceTable = false, "orc") + } - sql("CREATE TABLE ctas1 stored as parquet AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "parquet") - sql("DROP TABLE ctas1") - } finally { - setConf(SQLConf.CONVERT_CTAS, originalConf) - sql("DROP TABLE IF EXISTS ctas1") + withTable("ctas1") { + sql( + """ + |CREATE TABLE ctas1 stored as parquet + |AS SELECT key k, value FROM src ORDER BY k, value + """.stripMargin) + checkRelation("ctas1", isDataSourceTable = false, "parquet") + } } } @@ -539,30 +543,40 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { val defaultDataSource = sessionState.conf.defaultDataSourceName val tempLocation = dir.toURI.getPath.stripSuffix("/") - sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c1")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c1'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = true, defaultDataSource, Some(s"file:$tempLocation/c1")) + } - sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", true, defaultDataSource, Some(s"file:$tempLocation/c2")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 LOCATION 'file:$tempLocation/c2'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = true, defaultDataSource, Some(s"file:$tempLocation/c2")) + } - sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "text", Some(s"file:$tempLocation/c3")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as textfile LOCATION 'file:$tempLocation/c3'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "text", Some(s"file:$tempLocation/c3")) + } - sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "sequence", Some(s"file:$tempLocation/c4")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as sequenceFile LOCATION 'file:$tempLocation/c4'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "sequence", Some(s"file:$tempLocation/c4")) + } - sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + - " AS SELECT key k, value FROM src ORDER BY k, value") - checkRelation("ctas1", false, "rcfile", Some(s"file:$tempLocation/c5")) - sql("DROP TABLE ctas1") + withTable("ctas1") { + sql(s"CREATE TABLE ctas1 stored as rcfile LOCATION 'file:$tempLocation/c5'" + + " AS SELECT key k, value FROM src ORDER BY k, value") + checkRelation( + "ctas1", isDataSourceTable = false, "rcfile", Some(s"file:$tempLocation/c5")) + } } } } From b96248862589bae1ddcdb14ce4c802789a001306 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Wed, 3 Jan 2018 22:18:13 +0800 Subject: [PATCH 17/55] [SPARK-20236][SQL] dynamic partition overwrite ## What changes were proposed in this pull request? When overwriting a partitioned table with dynamic partition columns, the behavior is different between data source and hive tables. data source table: delete all partition directories that match the static partition values provided in the insert statement. hive table: only delete partition directories which have data written into it This PR adds a new config to make users be able to choose hive's behavior. ## How was this patch tested? new tests Author: Wenchen Fan Closes #18714 from cloud-fan/overwrite-partition. (cherry picked from commit a66fe36cee9363b01ee70e469f1c968f633c5713) Signed-off-by: gatorsmile --- .../internal/io/FileCommitProtocol.scala | 25 ++++-- .../io/HadoopMapReduceCommitProtocol.scala | 75 ++++++++++++++---- .../apache/spark/sql/internal/SQLConf.scala | 21 +++++ .../InsertIntoHadoopFsRelationCommand.scala | 20 ++++- .../SQLHadoopMapReduceCommitProtocol.scala | 10 ++- .../spark/sql/sources/InsertSuite.scala | 78 +++++++++++++++++++ 6 files changed, 200 insertions(+), 29 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala index 50f51e1af4530..6d0059b6a0272 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/FileCommitProtocol.scala @@ -28,8 +28,9 @@ import org.apache.spark.util.Utils * * 1. Implementations must be serializable, as the committer instance instantiated on the driver * will be used for tasks on executors. - * 2. Implementations should have a constructor with 2 arguments: - * (jobId: String, path: String) + * 2. Implementations should have a constructor with 2 or 3 arguments: + * (jobId: String, path: String) or + * (jobId: String, path: String, dynamicPartitionOverwrite: Boolean) * 3. A committer should not be reused across multiple Spark jobs. * * The proper call sequence is: @@ -139,10 +140,22 @@ object FileCommitProtocol { /** * Instantiates a FileCommitProtocol using the given className. */ - def instantiate(className: String, jobId: String, outputPath: String) - : FileCommitProtocol = { + def instantiate( + className: String, + jobId: String, + outputPath: String, + dynamicPartitionOverwrite: Boolean = false): FileCommitProtocol = { val clazz = Utils.classForName(className).asInstanceOf[Class[FileCommitProtocol]] - val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) - ctor.newInstance(jobId, outputPath) + // First try the constructor with arguments (jobId: String, outputPath: String, + // dynamicPartitionOverwrite: Boolean). + // If that doesn't exist, try the one with (jobId: string, outputPath: String). + try { + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String], classOf[Boolean]) + ctor.newInstance(jobId, outputPath, dynamicPartitionOverwrite.asInstanceOf[java.lang.Boolean]) + } catch { + case _: NoSuchMethodException => + val ctor = clazz.getDeclaredConstructor(classOf[String], classOf[String]) + ctor.newInstance(jobId, outputPath) + } } } diff --git a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala index 95c99d29c3a9c..6d20ef1f98a3c 100644 --- a/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala +++ b/core/src/main/scala/org/apache/spark/internal/io/HadoopMapReduceCommitProtocol.scala @@ -39,8 +39,19 @@ import org.apache.spark.mapred.SparkHadoopMapRedUtil * * @param jobId the job's or stage's id * @param path the job's output path, or null if committer acts as a noop + * @param dynamicPartitionOverwrite If true, Spark will overwrite partition directories at runtime + * dynamically, i.e., we first write files under a staging + * directory with partition path, e.g. + * /path/to/staging/a=1/b=1/xxx.parquet. When committing the job, + * we first clean up the corresponding partition directories at + * destination path, e.g. /path/to/destination/a=1/b=1, and move + * files from staging directory to the corresponding partition + * directories under destination path. */ -class HadoopMapReduceCommitProtocol(jobId: String, path: String) +class HadoopMapReduceCommitProtocol( + jobId: String, + path: String, + dynamicPartitionOverwrite: Boolean = false) extends FileCommitProtocol with Serializable with Logging { import FileCommitProtocol._ @@ -67,9 +78,17 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) @transient private var addedAbsPathFiles: mutable.Map[String, String] = null /** - * The staging directory for all files committed with absolute output paths. + * Tracks partitions with default path that have new files written into them by this task, + * e.g. a=1/b=2. Files under these partitions will be saved into staging directory and moved to + * destination directory at the end, if `dynamicPartitionOverwrite` is true. */ - private def absPathStagingDir: Path = new Path(path, "_temporary-" + jobId) + @transient private var partitionPaths: mutable.Set[String] = null + + /** + * The staging directory of this write job. Spark uses it to deal with files with absolute output + * path, or writing data into partitioned directory with dynamicPartitionOverwrite=true. + */ + private def stagingDir = new Path(path, ".spark-staging-" + jobId) protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { val format = context.getOutputFormatClass.newInstance() @@ -85,11 +104,16 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) taskContext: TaskAttemptContext, dir: Option[String], ext: String): String = { val filename = getFilename(taskContext, ext) - val stagingDir: String = committer match { + val stagingDir: Path = committer match { + case _ if dynamicPartitionOverwrite => + assert(dir.isDefined, + "The dataset to be written must be partitioned when dynamicPartitionOverwrite is true.") + partitionPaths += dir.get + this.stagingDir // For FileOutputCommitter it has its own staging path called "work path". case f: FileOutputCommitter => - Option(f.getWorkPath).map(_.toString).getOrElse(path) - case _ => path + new Path(Option(f.getWorkPath).map(_.toString).getOrElse(path)) + case _ => new Path(path) } dir.map { d => @@ -106,8 +130,7 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) // Include a UUID here to prevent file collisions for one task writing to different dirs. // In principle we could include hash(absoluteDir) instead but this is simpler. - val tmpOutputPath = new Path( - absPathStagingDir, UUID.randomUUID().toString() + "-" + filename).toString + val tmpOutputPath = new Path(stagingDir, UUID.randomUUID().toString() + "-" + filename).toString addedAbsPathFiles(tmpOutputPath) = absOutputPath tmpOutputPath @@ -141,23 +164,42 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) override def commitJob(jobContext: JobContext, taskCommits: Seq[TaskCommitMessage]): Unit = { committer.commitJob(jobContext) - val filesToMove = taskCommits.map(_.obj.asInstanceOf[Map[String, String]]) - .foldLeft(Map[String, String]())(_ ++ _) - logDebug(s"Committing files staged for absolute locations $filesToMove") + if (hasValidPath) { - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) + val (allAbsPathFiles, allPartitionPaths) = + taskCommits.map(_.obj.asInstanceOf[(Map[String, String], Set[String])]).unzip + val fs = stagingDir.getFileSystem(jobContext.getConfiguration) + + val filesToMove = allAbsPathFiles.foldLeft(Map[String, String]())(_ ++ _) + logDebug(s"Committing files staged for absolute locations $filesToMove") + if (dynamicPartitionOverwrite) { + val absPartitionPaths = filesToMove.values.map(new Path(_).getParent).toSet + logDebug(s"Clean up absolute partition directories for overwriting: $absPartitionPaths") + absPartitionPaths.foreach(fs.delete(_, true)) + } for ((src, dst) <- filesToMove) { fs.rename(new Path(src), new Path(dst)) } - fs.delete(absPathStagingDir, true) + + if (dynamicPartitionOverwrite) { + val partitionPaths = allPartitionPaths.foldLeft(Set[String]())(_ ++ _) + logDebug(s"Clean up default partition directories for overwriting: $partitionPaths") + for (part <- partitionPaths) { + val finalPartPath = new Path(path, part) + fs.delete(finalPartPath, true) + fs.rename(new Path(stagingDir, part), finalPartPath) + } + } + + fs.delete(stagingDir, true) } } override def abortJob(jobContext: JobContext): Unit = { committer.abortJob(jobContext, JobStatus.State.FAILED) if (hasValidPath) { - val fs = absPathStagingDir.getFileSystem(jobContext.getConfiguration) - fs.delete(absPathStagingDir, true) + val fs = stagingDir.getFileSystem(jobContext.getConfiguration) + fs.delete(stagingDir, true) } } @@ -165,13 +207,14 @@ class HadoopMapReduceCommitProtocol(jobId: String, path: String) committer = setupCommitter(taskContext) committer.setupTask(taskContext) addedAbsPathFiles = mutable.Map[String, String]() + partitionPaths = mutable.Set[String]() } override def commitTask(taskContext: TaskAttemptContext): TaskCommitMessage = { val attemptId = taskContext.getTaskAttemptID SparkHadoopMapRedUtil.commitTask( committer, taskContext, attemptId.getJobID.getId, attemptId.getTaskID.getId) - new TaskCommitMessage(addedAbsPathFiles.toMap) + new TaskCommitMessage(addedAbsPathFiles.toMap -> partitionPaths.toSet) } override def abortTask(taskContext: TaskAttemptContext): Unit = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 80cdc61484c0f..5d6edf6b8abec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1068,6 +1068,24 @@ object SQLConf { .timeConf(TimeUnit.MILLISECONDS) .createWithDefault(100) + object PartitionOverwriteMode extends Enumeration { + val STATIC, DYNAMIC = Value + } + + val PARTITION_OVERWRITE_MODE = + buildConf("spark.sql.sources.partitionOverwriteMode") + .doc("When INSERT OVERWRITE a partitioned data source table, we currently support 2 modes: " + + "static and dynamic. In static mode, Spark deletes all the partitions that match the " + + "partition specification(e.g. PARTITION(a=1,b)) in the INSERT statement, before " + + "overwriting. In dynamic mode, Spark doesn't delete partitions ahead, and only overwrite " + + "those partitions that have data written into it at runtime. By default we use static " + + "mode to keep the same behavior of Spark prior to 2.3. Note that this config doesn't " + + "affect Hive serde tables, as they are always overwritten with dynamic mode.") + .stringConf + .transform(_.toUpperCase(Locale.ROOT)) + .checkValues(PartitionOverwriteMode.values.map(_.toString)) + .createWithDefault(PartitionOverwriteMode.STATIC.toString) + object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } @@ -1394,6 +1412,9 @@ class SQLConf extends Serializable with Logging { def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) + def partitionOverwriteMode: PartitionOverwriteMode.Value = + PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala index ad24e280d942a..dd7ef0d15c140 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoHadoopFsRelationCommand.scala @@ -29,6 +29,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.command._ +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.util.SchemaUtils /** @@ -89,13 +90,19 @@ case class InsertIntoHadoopFsRelationCommand( } val pathExists = fs.exists(qualifiedOutputPath) - // If we are appending data to an existing dir. - val isAppend = pathExists && (mode == SaveMode.Append) + + val enableDynamicOverwrite = + sparkSession.sessionState.conf.partitionOverwriteMode == PartitionOverwriteMode.DYNAMIC + // This config only makes sense when we are overwriting a partitioned dataset with dynamic + // partition columns. + val dynamicPartitionOverwrite = enableDynamicOverwrite && mode == SaveMode.Overwrite && + staticPartitions.size < partitionColumns.length val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, jobId = java.util.UUID.randomUUID().toString, - outputPath = outputPath.toString) + outputPath = outputPath.toString, + dynamicPartitionOverwrite = dynamicPartitionOverwrite) val doInsertion = (mode, pathExists) match { case (SaveMode.ErrorIfExists, true) => @@ -103,6 +110,9 @@ case class InsertIntoHadoopFsRelationCommand( case (SaveMode.Overwrite, true) => if (ifPartitionNotExists && matchingPartitions.nonEmpty) { false + } else if (dynamicPartitionOverwrite) { + // For dynamic partition overwrite, do not delete partition directories ahead. + true } else { deleteMatchingPartitions(fs, qualifiedOutputPath, customPartitionLocations, committer) true @@ -126,7 +136,9 @@ case class InsertIntoHadoopFsRelationCommand( catalogTable.get.identifier, newPartitions.toSeq.map(p => (p, None)), ifNotExists = true).run(sparkSession) } - if (mode == SaveMode.Overwrite) { + // For dynamic partition overwrite, we never remove partitions but only update existing + // ones. + if (mode == SaveMode.Overwrite && !dynamicPartitionOverwrite) { val deletedPartitions = initialMatchingPartitions.toSet -- updatedPartitions if (deletedPartitions.nonEmpty) { AlterTableDropPartitionCommand( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala index 40825a1f724b1..39c594a9bc618 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/SQLHadoopMapReduceCommitProtocol.scala @@ -29,11 +29,15 @@ import org.apache.spark.sql.internal.SQLConf * A variant of [[HadoopMapReduceCommitProtocol]] that allows specifying the actual * Hadoop output committer using an option specified in SQLConf. */ -class SQLHadoopMapReduceCommitProtocol(jobId: String, path: String) - extends HadoopMapReduceCommitProtocol(jobId, path) with Serializable with Logging { +class SQLHadoopMapReduceCommitProtocol( + jobId: String, + path: String, + dynamicPartitionOverwrite: Boolean = false) + extends HadoopMapReduceCommitProtocol(jobId, path, dynamicPartitionOverwrite) + with Serializable with Logging { override protected def setupCommitter(context: TaskAttemptContext): OutputCommitter = { - var committer = context.getOutputFormatClass.newInstance().getOutputCommitter(context) + var committer = super.setupCommitter(context) val configuration = context.getConfiguration val clazz = diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index 8b7e2e5f45946..fef01c860db6e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -21,6 +21,8 @@ import java.io.File import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, Row} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.internal.SQLConf.PartitionOverwriteMode import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.util.Utils @@ -442,4 +444,80 @@ class InsertSuite extends DataSourceTest with SharedSQLContext { assert(e.contains("Only Data Sources providing FileFormat are supported")) } } + + test("SPARK-20236: dynamic partition overwrite without catalog table") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTempPath { path => + Seq((1, 1, 1)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(1, 1, 1)) + + Seq((2, 1, 1)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").mode("overwrite").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(2, 1, 1)) + + Seq((2, 2, 2)).toDF("i", "part1", "part2") + .write.partitionBy("part1", "part2").mode("overwrite").parquet(path.getAbsolutePath) + checkAnswer(spark.read.parquet(path.getAbsolutePath), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + } + } + } + + test("SPARK-20236: dynamic partition overwrite") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTable("t") { + sql( + """ + |create table t(i int, part1 int, part2 int) using parquet + |partitioned by (part1, part2) + """.stripMargin) + + sql("insert into t partition(part1=1, part2=1) select 1") + checkAnswer(spark.table("t"), Row(1, 1, 1)) + + sql("insert overwrite table t partition(part1=1, part2=1) select 2") + checkAnswer(spark.table("t"), Row(2, 1, 1)) + + sql("insert overwrite table t partition(part1=2, part2) select 2, 2") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2=2) select 3") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2) select 4, 1") + checkAnswer(spark.table("t"), Row(4, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + } + } + } + + test("SPARK-20236: dynamic partition overwrite with customer partition path") { + withSQLConf(SQLConf.PARTITION_OVERWRITE_MODE.key -> PartitionOverwriteMode.DYNAMIC.toString) { + withTable("t") { + sql( + """ + |create table t(i int, part1 int, part2 int) using parquet + |partitioned by (part1, part2) + """.stripMargin) + + val path1 = Utils.createTempDir() + sql(s"alter table t add partition(part1=1, part2=1) location '$path1'") + sql(s"insert into t partition(part1=1, part2=1) select 1") + checkAnswer(spark.table("t"), Row(1, 1, 1)) + + sql("insert overwrite table t partition(part1=1, part2=1) select 2") + checkAnswer(spark.table("t"), Row(2, 1, 1)) + + sql("insert overwrite table t partition(part1=2, part2) select 2, 2") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Nil) + + val path2 = Utils.createTempDir() + sql(s"alter table t add partition(part1=1, part2=2) location '$path2'") + sql("insert overwrite table t partition(part1=1, part2=2) select 3") + checkAnswer(spark.table("t"), Row(2, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + + sql("insert overwrite table t partition(part1=1, part2) select 4, 1") + checkAnswer(spark.table("t"), Row(4, 1, 1) :: Row(2, 2, 2) :: Row(3, 1, 2) :: Nil) + } + } + } } From 27c949d673e45fdbbae0f2c08969b9d51222dd8d Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Tue, 2 Jan 2018 09:19:18 +0800 Subject: [PATCH 18/55] [SPARK-22932][SQL] Refactor AnalysisContext ## What changes were proposed in this pull request? Add a `reset` function to ensure the state in `AnalysisContext ` is per-query. ## How was this patch tested? The existing test cases Author: gatorsmile Closes #20127 from gatorsmile/refactorAnalysisContext. --- .../sql/catalyst/analysis/Analyzer.scala | 25 +++++++++++++++---- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 6d294d48c0ee7..35b35110e491f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -52,6 +52,7 @@ object SimpleAnalyzer extends Analyzer( /** * Provides a way to keep state during the analysis, this enables us to decouple the concerns * of analysis environment from the catalog. + * The state that is kept here is per-query. * * Note this is thread local. * @@ -70,6 +71,8 @@ object AnalysisContext { } def get: AnalysisContext = value.get() + def reset(): Unit = value.remove() + private def set(context: AnalysisContext): Unit = value.set(context) def withAnalysisContext[A](database: Option[String])(f: => A): A = { @@ -95,6 +98,17 @@ class Analyzer( this(catalog, conf, conf.optimizerMaxIterations) } + override def execute(plan: LogicalPlan): LogicalPlan = { + AnalysisContext.reset() + try { + executeSameContext(plan) + } finally { + AnalysisContext.reset() + } + } + + private def executeSameContext(plan: LogicalPlan): LogicalPlan = super.execute(plan) + def resolver: Resolver = conf.resolver protected val fixedPoint = FixedPoint(maxIterations) @@ -176,7 +190,7 @@ class Analyzer( case With(child, relations) => substituteCTE(child, relations.foldLeft(Seq.empty[(String, LogicalPlan)]) { case (resolved, (name, relation)) => - resolved :+ name -> execute(substituteCTE(relation, resolved)) + resolved :+ name -> executeSameContext(substituteCTE(relation, resolved)) }) case other => other } @@ -600,7 +614,7 @@ class Analyzer( "avoid errors. Increase the value of spark.sql.view.maxNestedViewDepth to work " + "aroud this.") } - execute(child) + executeSameContext(child) } view.copy(child = newChild) case p @ SubqueryAlias(_, view: View) => @@ -1269,7 +1283,7 @@ class Analyzer( do { // Try to resolve the subquery plan using the regular analyzer. previous = current - current = execute(current) + current = executeSameContext(current) // Use the outer references to resolve the subquery plan if it isn't resolved yet. val i = plans.iterator @@ -1392,7 +1406,7 @@ class Analyzer( grouping, Alias(cond, "havingCondition")() :: Nil, child) - val resolvedOperator = execute(aggregatedCondition) + val resolvedOperator = executeSameContext(aggregatedCondition) def resolvedAggregateFilter = resolvedOperator .asInstanceOf[Aggregate] @@ -1450,7 +1464,8 @@ class Analyzer( val aliasedOrdering = unresolvedSortOrders.map(o => Alias(o.child, "aggOrder")()) val aggregatedOrdering = aggregate.copy(aggregateExpressions = aliasedOrdering) - val resolvedAggregate: Aggregate = execute(aggregatedOrdering).asInstanceOf[Aggregate] + val resolvedAggregate: Aggregate = + executeSameContext(aggregatedOrdering).asInstanceOf[Aggregate] val resolvedAliasedOrdering: Seq[Alias] = resolvedAggregate.aggregateExpressions.asInstanceOf[Seq[Alias]] From 79f7263daa5f83e2026fda9a8bbb1090a1333f80 Mon Sep 17 00:00:00 2001 From: chetkhatri Date: Wed, 3 Jan 2018 11:31:32 -0600 Subject: [PATCH 19/55] [SPARK-22896] Improvement in String interpolation ## What changes were proposed in this pull request? * String interpolation in ml pipeline example has been corrected as per scala standard. ## How was this patch tested? * manually tested. Author: chetkhatri Closes #20070 from chetkhatri/mllib-chetan-contrib. (cherry picked from commit 9a2b65a3c0c36316aae0a53aa0f61c5044c2ceff) Signed-off-by: Sean Owen --- .../spark/examples/ml/JavaQuantileDiscretizerExample.java | 2 +- .../apache/spark/examples/SimpleSkewedGroupByTest.scala | 4 ---- .../org/apache/spark/examples/graphx/Analytics.scala | 6 ++++-- .../org/apache/spark/examples/graphx/SynthBenchmark.scala | 6 +++--- .../apache/spark/examples/ml/ChiSquareTestExample.scala | 6 +++--- .../org/apache/spark/examples/ml/CorrelationExample.scala | 4 ++-- .../org/apache/spark/examples/ml/DataFrameExample.scala | 4 ++-- .../examples/ml/DecisionTreeClassificationExample.scala | 4 ++-- .../spark/examples/ml/DecisionTreeRegressionExample.scala | 4 ++-- .../apache/spark/examples/ml/DeveloperApiExample.scala | 6 +++--- .../examples/ml/EstimatorTransformerParamExample.scala | 6 +++--- .../ml/GradientBoostedTreeClassifierExample.scala | 4 ++-- .../examples/ml/GradientBoostedTreeRegressorExample.scala | 4 ++-- ...ulticlassLogisticRegressionWithElasticNetExample.scala | 2 +- .../ml/MultilayerPerceptronClassifierExample.scala | 2 +- .../org/apache/spark/examples/ml/NaiveBayesExample.scala | 2 +- .../spark/examples/ml/QuantileDiscretizerExample.scala | 4 ++-- .../spark/examples/ml/RandomForestClassifierExample.scala | 4 ++-- .../spark/examples/ml/RandomForestRegressorExample.scala | 4 ++-- .../apache/spark/examples/ml/VectorIndexerExample.scala | 4 ++-- .../spark/examples/mllib/AssociationRulesExample.scala | 6 +++--- .../mllib/BinaryClassificationMetricsExample.scala | 4 ++-- .../mllib/DecisionTreeClassificationExample.scala | 4 ++-- .../examples/mllib/DecisionTreeRegressionExample.scala | 4 ++-- .../org/apache/spark/examples/mllib/FPGrowthExample.scala | 2 +- .../mllib/GradientBoostingClassificationExample.scala | 4 ++-- .../mllib/GradientBoostingRegressionExample.scala | 4 ++-- .../spark/examples/mllib/HypothesisTestingExample.scala | 2 +- .../spark/examples/mllib/IsotonicRegressionExample.scala | 2 +- .../org/apache/spark/examples/mllib/KMeansExample.scala | 2 +- .../org/apache/spark/examples/mllib/LBFGSExample.scala | 2 +- .../examples/mllib/LatentDirichletAllocationExample.scala | 8 +++++--- .../examples/mllib/LinearRegressionWithSGDExample.scala | 2 +- .../org/apache/spark/examples/mllib/PCAExample.scala | 4 ++-- .../spark/examples/mllib/PMMLModelExportExample.scala | 2 +- .../apache/spark/examples/mllib/PrefixSpanExample.scala | 4 ++-- .../mllib/RandomForestClassificationExample.scala | 4 ++-- .../examples/mllib/RandomForestRegressionExample.scala | 4 ++-- .../spark/examples/mllib/RecommendationExample.scala | 2 +- .../apache/spark/examples/mllib/SVMWithSGDExample.scala | 2 +- .../org/apache/spark/examples/mllib/SimpleFPGrowth.scala | 8 +++----- .../spark/examples/mllib/StratifiedSamplingExample.scala | 4 ++-- .../org/apache/spark/examples/mllib/TallSkinnyPCA.scala | 2 +- .../org/apache/spark/examples/mllib/TallSkinnySVD.scala | 2 +- .../apache/spark/examples/streaming/CustomReceiver.scala | 6 +++--- .../apache/spark/examples/streaming/RawNetworkGrep.scala | 2 +- .../examples/streaming/RecoverableNetworkWordCount.scala | 8 ++++---- .../streaming/clickstream/PageViewGenerator.scala | 4 ++-- .../examples/streaming/clickstream/PageViewStream.scala | 4 ++-- 49 files changed, 94 insertions(+), 96 deletions(-) diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java index dd20cac621102..43cc30c1a899b 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaQuantileDiscretizerExample.java @@ -66,7 +66,7 @@ public static void main(String[] args) { .setNumBuckets(3); Dataset result = discretizer.fit(df).transform(df); - result.show(); + result.show(false); // $example off$ spark.stop(); } diff --git a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala index e64dcbd182d94..2332a661f26a0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SimpleSkewedGroupByTest.scala @@ -60,10 +60,6 @@ object SimpleSkewedGroupByTest { pairs1.count println(s"RESULT: ${pairs1.groupByKey(numReducers).count}") - // Print how many keys each reducer got (for debugging) - // println("RESULT: " + pairs1.groupByKey(numReducers) - // .map{case (k,v) => (k, v.size)} - // .collectAsMap) spark.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala index 92936bd30dbc0..815404d1218b7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/Analytics.scala @@ -145,9 +145,11 @@ object Analytics extends Logging { // TriangleCount requires the graph to be partitioned .partitionBy(partitionStrategy.getOrElse(RandomVertexCut)).cache() val triangles = TriangleCount.run(graph) - println("Triangles: " + triangles.vertices.map { + val triangleTypes = triangles.vertices.map { case (vid, data) => data.toLong - }.reduce(_ + _) / 3) + }.reduce(_ + _) / 3 + + println(s"Triangles: ${triangleTypes}") sc.stop() case _ => diff --git a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala index 6d2228c8742aa..57b2edf992208 100644 --- a/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala +++ b/examples/src/main/scala/org/apache/spark/examples/graphx/SynthBenchmark.scala @@ -52,7 +52,7 @@ object SynthBenchmark { arg => arg.dropWhile(_ == '-').split('=') match { case Array(opt, v) => (opt -> v) - case _ => throw new IllegalArgumentException("Invalid argument: " + arg) + case _ => throw new IllegalArgumentException(s"Invalid argument: $arg") } } @@ -76,7 +76,7 @@ object SynthBenchmark { case ("sigma", v) => sigma = v.toDouble case ("degFile", v) => degFile = v case ("seed", v) => seed = v.toInt - case (opt, _) => throw new IllegalArgumentException("Invalid option: " + opt) + case (opt, _) => throw new IllegalArgumentException(s"Invalid option: $opt") } val conf = new SparkConf() @@ -86,7 +86,7 @@ object SynthBenchmark { val sc = new SparkContext(conf) // Create the graph - println(s"Creating graph...") + println("Creating graph...") val unpartitionedGraph = GraphGenerators.logNormalGraph(sc, numVertices, numEPart.getOrElse(sc.defaultParallelism), mu, sigma, seed) // Repartition the graph diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala index dcee1e427ce58..5146fd0316467 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/ChiSquareTestExample.scala @@ -52,9 +52,9 @@ object ChiSquareTestExample { val df = data.toDF("label", "features") val chi = ChiSquareTest.test(df, "features", "label").head - println("pValues = " + chi.getAs[Vector](0)) - println("degreesOfFreedom = " + chi.getSeq[Int](1).mkString("[", ",", "]")) - println("statistics = " + chi.getAs[Vector](2)) + println(s"pValues = ${chi.getAs[Vector](0)}") + println(s"degreesOfFreedom ${chi.getSeq[Int](1).mkString("[", ",", "]")}") + println(s"statistics ${chi.getAs[Vector](2)}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala index 3f57dc342eb00..d7f1fc8ed74d7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/CorrelationExample.scala @@ -51,10 +51,10 @@ object CorrelationExample { val df = data.map(Tuple1.apply).toDF("features") val Row(coeff1: Matrix) = Correlation.corr(df, "features").head - println("Pearson correlation matrix:\n" + coeff1.toString) + println(s"Pearson correlation matrix:\n $coeff1") val Row(coeff2: Matrix) = Correlation.corr(df, "features", "spearman").head - println("Spearman correlation matrix:\n" + coeff2.toString) + println(s"Spearman correlation matrix:\n $coeff2") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala index 0658bddf16961..ee4469faab3a0 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DataFrameExample.scala @@ -47,7 +47,7 @@ object DataFrameExample { val parser = new OptionParser[Params]("DataFrameExample") { head("DataFrameExample: an example app using DataFrame for ML.") opt[String]("input") - .text(s"input path to dataframe") + .text("input path to dataframe") .action((x, c) => c.copy(input = x)) checkConfig { params => success @@ -93,7 +93,7 @@ object DataFrameExample { // Load the records back. println(s"Loading Parquet file with UDT from $outputDir.") val newDF = spark.read.parquet(outputDir) - println(s"Schema from Parquet:") + println("Schema from Parquet:") newDF.printSchema() spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala index bc6d3275933ea..276cedab11abc 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeClassificationExample.scala @@ -83,10 +83,10 @@ object DecisionTreeClassificationExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${(1.0 - accuracy)}") val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel] - println("Learned classification tree model:\n" + treeModel.toDebugString) + println(s"Learned classification tree model:\n ${treeModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala index ee61200ad1d0c..aaaecaea47081 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeRegressionExample.scala @@ -73,10 +73,10 @@ object DecisionTreeRegressionExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val treeModel = model.stages(1).asInstanceOf[DecisionTreeRegressionModel] - println("Learned regression tree model:\n" + treeModel.toDebugString) + println(s"Learned regression tree model:\n ${treeModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index d94d837d10e96..2dc11b07d88ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -53,7 +53,7 @@ object DeveloperApiExample { // Create a LogisticRegression instance. This instance is an Estimator. val lr = new MyLogisticRegression() // Print out the parameters, documentation, and any default values. - println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n") + println(s"MyLogisticRegression parameters:\n ${lr.explainParams()}") // We may set parameters using setter methods. lr.setMaxIter(10) @@ -169,10 +169,10 @@ private class MyLogisticRegressionModel( Vectors.dense(-margin, margin) } - /** Number of classes the label can take. 2 indicates binary classification. */ + // Number of classes the label can take. 2 indicates binary classification. override val numClasses: Int = 2 - /** Number of features the model was trained on. */ + // Number of features the model was trained on. override val numFeatures: Int = coefficients.size /** diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala index f18d86e1a6921..e5d91f132a3f2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/EstimatorTransformerParamExample.scala @@ -46,7 +46,7 @@ object EstimatorTransformerParamExample { // Create a LogisticRegression instance. This instance is an Estimator. val lr = new LogisticRegression() // Print out the parameters, documentation, and any default values. - println("LogisticRegression parameters:\n" + lr.explainParams() + "\n") + println(s"LogisticRegression parameters:\n ${lr.explainParams()}\n") // We may set parameters using setter methods. lr.setMaxIter(10) @@ -58,7 +58,7 @@ object EstimatorTransformerParamExample { // we can view the parameters it used during fit(). // This prints the parameter (name: value) pairs, where names are unique IDs for this // LogisticRegression instance. - println("Model 1 was fit using parameters: " + model1.parent.extractParamMap) + println(s"Model 1 was fit using parameters: ${model1.parent.extractParamMap}") // We may alternatively specify parameters using a ParamMap, // which supports several methods for specifying parameters. @@ -73,7 +73,7 @@ object EstimatorTransformerParamExample { // Now learn a new model using the paramMapCombined parameters. // paramMapCombined overrides all parameters set earlier via lr.set* methods. val model2 = lr.fit(training, paramMapCombined) - println("Model 2 was fit using parameters: " + model2.parent.extractParamMap) + println(s"Model 2 was fit using parameters: ${model2.parent.extractParamMap}") // Prepare test data. val test = spark.createDataFrame(Seq( diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala index 3656773c8b817..ef78c0a1145ef 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeClassifierExample.scala @@ -86,10 +86,10 @@ object GradientBoostedTreeClassifierExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${1.0 - accuracy}") val gbtModel = model.stages(2).asInstanceOf[GBTClassificationModel] - println("Learned classification GBT model:\n" + gbtModel.toDebugString) + println(s"Learned classification GBT model:\n ${gbtModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala index e53aab7f326d3..3feb2343f6a85 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/GradientBoostedTreeRegressorExample.scala @@ -73,10 +73,10 @@ object GradientBoostedTreeRegressorExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val gbtModel = model.stages(1).asInstanceOf[GBTRegressionModel] - println("Learned regression GBT model:\n" + gbtModel.toDebugString) + println(s"Learned regression GBT model:\n ${gbtModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala index 42f0ace7a353d..3e61dbe628c20 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MulticlassLogisticRegressionWithElasticNetExample.scala @@ -48,7 +48,7 @@ object MulticlassLogisticRegressionWithElasticNetExample { // Print the coefficients and intercept for multinomial logistic regression println(s"Coefficients: \n${lrModel.coefficientMatrix}") - println(s"Intercepts: ${lrModel.interceptVector}") + println(s"Intercepts: \n${lrModel.interceptVector}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala index 6fce82d294f8d..646f46a925062 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/MultilayerPerceptronClassifierExample.scala @@ -66,7 +66,7 @@ object MultilayerPerceptronClassifierExample { val evaluator = new MulticlassClassificationEvaluator() .setMetricName("accuracy") - println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels)) + println(s"Test set accuracy = ${evaluator.evaluate(predictionAndLabels)}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala index bd9fcc420a66c..50c70c626b128 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/NaiveBayesExample.scala @@ -52,7 +52,7 @@ object NaiveBayesExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test set accuracy = " + accuracy) + println(s"Test set accuracy = $accuracy") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala index aedb9e7d3bb70..0fe16fb6dfa9f 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/QuantileDiscretizerExample.scala @@ -36,7 +36,7 @@ object QuantileDiscretizerExample { // Output of QuantileDiscretizer for such small datasets can depend on the number of // partitions. Here we force a single partition to ensure consistent results. // Note this is not necessary for normal use cases - .repartition(1) + .repartition(1) // $example on$ val discretizer = new QuantileDiscretizer() @@ -45,7 +45,7 @@ object QuantileDiscretizerExample { .setNumBuckets(3) val result = discretizer.fit(df).transform(df) - result.show() + result.show(false) // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala index 5eafda8ce4285..6265f83902528 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestClassifierExample.scala @@ -85,10 +85,10 @@ object RandomForestClassifierExample { .setPredictionCol("prediction") .setMetricName("accuracy") val accuracy = evaluator.evaluate(predictions) - println("Test Error = " + (1.0 - accuracy)) + println(s"Test Error = ${(1.0 - accuracy)}") val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] - println("Learned classification forest model:\n" + rfModel.toDebugString) + println(s"Learned classification forest model:\n ${rfModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala index 9a0a001c26ef5..2679fcb353a8a 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestRegressorExample.scala @@ -72,10 +72,10 @@ object RandomForestRegressorExample { .setPredictionCol("prediction") .setMetricName("rmse") val rmse = evaluator.evaluate(predictions) - println("Root Mean Squared Error (RMSE) on test data = " + rmse) + println(s"Root Mean Squared Error (RMSE) on test data = $rmse") val rfModel = model.stages(1).asInstanceOf[RandomForestRegressionModel] - println("Learned regression forest model:\n" + rfModel.toDebugString) + println(s"Learned regression forest model:\n ${rfModel.toDebugString}") // $example off$ spark.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala index afa761aee0b98..96bb8ea2338af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/VectorIndexerExample.scala @@ -41,8 +41,8 @@ object VectorIndexerExample { val indexerModel = indexer.fit(data) val categoricalFeatures: Set[Int] = indexerModel.categoryMaps.keys.toSet - println(s"Chose ${categoricalFeatures.size} categorical features: " + - categoricalFeatures.mkString(", ")) + println(s"Chose ${categoricalFeatures.size} " + + s"categorical features: ${categoricalFeatures.mkString(", ")}") // Create new column "indexed" with categorical values transformed to indices val indexedData = indexerModel.transform(data) diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala index ff44de56839e5..a07535bb5a38d 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/AssociationRulesExample.scala @@ -42,9 +42,8 @@ object AssociationRulesExample { val results = ar.run(freqItemsets) results.collect().foreach { rule => - println("[" + rule.antecedent.mkString(",") - + "=>" - + rule.consequent.mkString(",") + "]," + rule.confidence) + println(s"[${rule.antecedent.mkString(",")}=>${rule.consequent.mkString(",")} ]" + + s" ${rule.confidence}") } // $example off$ @@ -53,3 +52,4 @@ object AssociationRulesExample { } // scalastyle:on println + diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala index b9263ac6fcff6..c6312d71cc912 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/BinaryClassificationMetricsExample.scala @@ -86,7 +86,7 @@ object BinaryClassificationMetricsExample { // AUPRC val auPRC = metrics.areaUnderPR - println("Area under precision-recall curve = " + auPRC) + println(s"Area under precision-recall curve = $auPRC") // Compute thresholds used in ROC and PR curves val thresholds = precision.map(_._1) @@ -96,7 +96,7 @@ object BinaryClassificationMetricsExample { // AUROC val auROC = metrics.areaUnderROC - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // $example off$ sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala index b50b4592777ce..c2f89b72c9a2e 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeClassificationExample.scala @@ -55,8 +55,8 @@ object DecisionTreeClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count().toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification tree model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification tree model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myDecisionTreeClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala index 2af45afae3d5b..1ecf6426e1f95 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRegressionExample.scala @@ -54,8 +54,8 @@ object DecisionTreeRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case (v, p) => math.pow(v - p, 2) }.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression tree model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression tree model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myDecisionTreeRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala index 6435abc127752..f724ee1030f04 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/FPGrowthExample.scala @@ -74,7 +74,7 @@ object FPGrowthExample { println(s"Number of frequent itemsets: ${model.freqItemsets.count()}") model.freqItemsets.collect().foreach { itemset => - println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + println(s"${itemset.items.mkString("[", ",", "]")}, ${itemset.freq}") } sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala index 00bb3348d2a36..3c56e1941aeca 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingClassificationExample.scala @@ -54,8 +54,8 @@ object GradientBoostingClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification GBT model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification GBT model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myGradientBoostingClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala index d8c263460839b..c288bf29bf255 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/GradientBoostingRegressionExample.scala @@ -53,8 +53,8 @@ object GradientBoostingRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression GBT model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression GBT model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myGradientBoostingRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala index 0d391a3637c07..add1719739539 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/HypothesisTestingExample.scala @@ -68,7 +68,7 @@ object HypothesisTestingExample { // against the label. val featureTestResults: Array[ChiSqTestResult] = Statistics.chiSqTest(obs) featureTestResults.zipWithIndex.foreach { case (k, v) => - println("Column " + (v + 1).toString + ":") + println(s"Column ${(v + 1)} :") println(k) } // summary of the test // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala index 4aee951f5b04c..a10d6f0dda880 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/IsotonicRegressionExample.scala @@ -56,7 +56,7 @@ object IsotonicRegressionExample { // Calculate mean squared error between predicted and real labels. val meanSquaredError = predictionAndLabel.map { case (p, l) => math.pow((p - l), 2) }.mean() - println("Mean Squared Error = " + meanSquaredError) + println(s"Mean Squared Error = $meanSquaredError") // Save and load model model.save(sc, "target/tmp/myIsotonicRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala index c4d71d862f375..b0a6f1671a898 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/KMeansExample.scala @@ -43,7 +43,7 @@ object KMeansExample { // Evaluate clustering by computing Within Set Sum of Squared Errors val WSSSE = clusters.computeCost(parsedData) - println("Within Set Sum of Squared Errors = " + WSSSE) + println(s"Within Set Sum of Squared Errors = $WSSSE") // Save and load model clusters.save(sc, "target/org/apache/spark/KMeansExample/KMeansModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala index fedcefa098381..123782fa6b9cf 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LBFGSExample.scala @@ -82,7 +82,7 @@ object LBFGSExample { println("Loss of each step in training process") loss.foreach(println) - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // $example off$ sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala index f2c8ec01439f1..d25962c5500ed 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LatentDirichletAllocationExample.scala @@ -42,11 +42,13 @@ object LatentDirichletAllocationExample { val ldaModel = new LDA().setK(3).run(corpus) // Output topics. Each is a distribution over words (matching word count vectors) - println("Learned topics (as distributions over vocab of " + ldaModel.vocabSize + " words):") + println(s"Learned topics (as distributions over vocab of ${ldaModel.vocabSize} words):") val topics = ldaModel.topicsMatrix for (topic <- Range(0, 3)) { - print("Topic " + topic + ":") - for (word <- Range(0, ldaModel.vocabSize)) { print(" " + topics(word, topic)); } + print(s"Topic $topic :") + for (word <- Range(0, ldaModel.vocabSize)) { + print(s"${topics(word, topic)}") + } println() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala index d399618094487..449b725d1d173 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LinearRegressionWithSGDExample.scala @@ -52,7 +52,7 @@ object LinearRegressionWithSGDExample { (point.label, prediction) } val MSE = valuesAndPreds.map{ case(v, p) => math.pow((v - p), 2) }.mean() - println("training Mean Squared Error = " + MSE) + println(s"training Mean Squared Error $MSE") // Save and load model model.save(sc, "target/tmp/scalaLinearRegressionWithSGDModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala index eb36697d94ba1..eff2393cc3abe 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PCAExample.scala @@ -65,8 +65,8 @@ object PCAExample { val MSE = valuesAndPreds.map { case (v, p) => math.pow((v - p), 2) }.mean() val MSE_pca = valuesAndPreds_pca.map { case (v, p) => math.pow((v - p), 2) }.mean() - println("Mean Squared Error = " + MSE) - println("PCA Mean Squared Error = " + MSE_pca) + println(s"Mean Squared Error = $MSE") + println(s"PCA Mean Squared Error = $MSE_pca") // $example off$ sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala index d74d74a37fb11..96deafd469bc7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PMMLModelExportExample.scala @@ -41,7 +41,7 @@ object PMMLModelExportExample { val clusters = KMeans.train(parsedData, numClusters, numIterations) // Export to PMML to a String in PMML format - println("PMML Model:\n" + clusters.toPMML) + println(s"PMML Model:\n ${clusters.toPMML}") // Export the model to a local file in PMML format clusters.toPMML("/tmp/kmeans.xml") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala index 69c72c4336576..8b789277774af 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/PrefixSpanExample.scala @@ -42,8 +42,8 @@ object PrefixSpanExample { val model = prefixSpan.run(sequences) model.freqSequences.collect().foreach { freqSequence => println( - freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]") + - ", " + freqSequence.freq) + s"${freqSequence.sequence.map(_.mkString("[", ", ", "]")).mkString("[", ", ", "]")}," + + s" ${freqSequence.freq}") } // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala index f1ebdf1a733ed..246e71de25615 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestClassificationExample.scala @@ -55,8 +55,8 @@ object RandomForestClassificationExample { (point.label, prediction) } val testErr = labelAndPreds.filter(r => r._1 != r._2).count.toDouble / testData.count() - println("Test Error = " + testErr) - println("Learned classification forest model:\n" + model.toDebugString) + println(s"Test Error = $testErr") + println(s"Learned classification forest model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myRandomForestClassificationModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala index 11d612e651b4b..770e30276bc30 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RandomForestRegressionExample.scala @@ -55,8 +55,8 @@ object RandomForestRegressionExample { (point.label, prediction) } val testMSE = labelsAndPredictions.map{ case(v, p) => math.pow((v - p), 2)}.mean() - println("Test Mean Squared Error = " + testMSE) - println("Learned regression forest model:\n" + model.toDebugString) + println(s"Test Mean Squared Error = $testMSE") + println(s"Learned regression forest model:\n ${model.toDebugString}") // Save and load model model.save(sc, "target/tmp/myRandomForestRegressionModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala index 6df742d737e70..0bb2b8c8c2b43 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/RecommendationExample.scala @@ -56,7 +56,7 @@ object RecommendationExample { val err = (r1 - r2) err * err }.mean() - println("Mean Squared Error = " + MSE) + println(s"Mean Squared Error = $MSE") // Save and load model model.save(sc, "target/tmp/myCollaborativeFilter") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala index b73fe9b2b3faa..285e2ce512639 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SVMWithSGDExample.scala @@ -57,7 +57,7 @@ object SVMWithSGDExample { val metrics = new BinaryClassificationMetrics(scoreAndLabels) val auROC = metrics.areaUnderROC() - println("Area under ROC = " + auROC) + println(s"Area under ROC = $auROC") // Save and load model model.save(sc, "target/tmp/scalaSVMWithSGDModel") diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala index b5c3033bcba09..694c3bb18b045 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/SimpleFPGrowth.scala @@ -42,15 +42,13 @@ object SimpleFPGrowth { val model = fpg.run(transactions) model.freqItemsets.collect().foreach { itemset => - println(itemset.items.mkString("[", ",", "]") + ", " + itemset.freq) + println(s"${itemset.items.mkString("[", ",", "]")},${itemset.freq}") } val minConfidence = 0.8 model.generateAssociationRules(minConfidence).collect().foreach { rule => - println( - rule.antecedent.mkString("[", ",", "]") - + " => " + rule.consequent .mkString("[", ",", "]") - + ", " + rule.confidence) + println(s"${rule.antecedent.mkString("[", ",", "]")}=> " + + s"${rule.consequent .mkString("[", ",", "]")},${rule.confidence}") } // $example off$ diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala index 16b074ef60699..3d41bef0af88c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/StratifiedSamplingExample.scala @@ -41,10 +41,10 @@ object StratifiedSamplingExample { val exactSample = data.sampleByKeyExact(withReplacement = false, fractions = fractions) // $example off$ - println("approxSample size is " + approxSample.collect().size.toString) + println(s"approxSample size is ${approxSample.collect().size}") approxSample.collect().foreach(println) - println("exactSample its size is " + exactSample.collect().size.toString) + println(s"exactSample its size is ${exactSample.collect().size}") exactSample.collect().foreach(println) sc.stop() diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala index 03bc675299c5a..071d341b81614 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnyPCA.scala @@ -54,7 +54,7 @@ object TallSkinnyPCA { // Compute principal components. val pc = mat.computePrincipalComponents(mat.numCols().toInt) - println("Principal components are:\n" + pc) + println(s"Principal components are:\n $pc") sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala index 067e49b9599e7..8ae6de16d80e7 100644 --- a/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala +++ b/examples/src/main/scala/org/apache/spark/examples/mllib/TallSkinnySVD.scala @@ -54,7 +54,7 @@ object TallSkinnySVD { // Compute SVD. val svd = mat.computeSVD(mat.numCols().toInt) - println("Singular values are " + svd.s) + println(s"Singular values are ${svd.s}") sc.stop() } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala index 43044d01b1204..25c7bf2871972 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/CustomReceiver.scala @@ -82,9 +82,9 @@ class CustomReceiver(host: String, port: Int) var socket: Socket = null var userInput: String = null try { - logInfo("Connecting to " + host + ":" + port) + logInfo(s"Connecting to $host : $port") socket = new Socket(host, port) - logInfo("Connected to " + host + ":" + port) + logInfo(s"Connected to $host : $port") val reader = new BufferedReader( new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)) userInput = reader.readLine() @@ -98,7 +98,7 @@ class CustomReceiver(host: String, port: Int) restart("Trying to connect again") } catch { case e: java.net.ConnectException => - restart("Error connecting to " + host + ":" + port, e) + restart(s"Error connecting to $host : $port", e) case t: Throwable => restart("Error receiving data", t) } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala index 5322929d177b4..437ccf0898d7c 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RawNetworkGrep.scala @@ -54,7 +54,7 @@ object RawNetworkGrep { ssc.rawSocketStream[String](host, port, StorageLevel.MEMORY_ONLY_SER_2)).toArray val union = ssc.union(rawStreams) union.filter(_.contains("the")).count().foreachRDD(r => - println("Grep count: " + r.collect().mkString)) + println(s"Grep count: ${r.collect().mkString}")) ssc.start() ssc.awaitTermination() } diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala index 49c0427321133..f018f3a26d2e9 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/RecoverableNetworkWordCount.scala @@ -130,10 +130,10 @@ object RecoverableNetworkWordCount { true } }.collect().mkString("[", ", ", "]") - val output = "Counts at time " + time + " " + counts + val output = s"Counts at time $time $counts" println(output) - println("Dropped " + droppedWordsCounter.value + " word(s) totally") - println("Appending to " + outputFile.getAbsolutePath) + println(s"Dropped ${droppedWordsCounter.value} word(s) totally") + println(s"Appending to ${outputFile.getAbsolutePath}") Files.append(output + "\n", outputFile, Charset.defaultCharset()) } ssc @@ -141,7 +141,7 @@ object RecoverableNetworkWordCount { def main(args: Array[String]) { if (args.length != 4) { - System.err.println("Your arguments were " + args.mkString("[", ", ", "]")) + System.err.println(s"Your arguments were ${args.mkString("[", ", ", "]")}") System.err.println( """ |Usage: RecoverableNetworkWordCount diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala index 0ddd065f0db2b..2108bc63edea2 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewGenerator.scala @@ -90,13 +90,13 @@ object PageViewGenerator { val viewsPerSecond = args(1).toFloat val sleepDelayMs = (1000.0 / viewsPerSecond).toInt val listener = new ServerSocket(port) - println("Listening on port: " + port) + println(s"Listening on port: $port") while (true) { val socket = listener.accept() new Thread() { override def run(): Unit = { - println("Got client connected from: " + socket.getInetAddress) + println(s"Got client connected from: ${socket.getInetAddress}") val out = new PrintWriter(socket.getOutputStream(), true) while (true) { diff --git a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala index 1ba093f57b32c..b8e7c7e9e9152 100644 --- a/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala +++ b/examples/src/main/scala/org/apache/spark/examples/streaming/clickstream/PageViewStream.scala @@ -104,8 +104,8 @@ object PageViewStream { .foreachRDD((rdd, time) => rdd.join(userList) .map(_._2._2) .take(10) - .foreach(u => println("Saw user %s at time %s".format(u, time)))) - case _ => println("Invalid metric entered: " + metric) + .foreach(u => println(s"Saw user $u at time $time"))) + case _ => println(s"Invalid metric entered: $metric") } ssc.start() From a51212b642f05f28447b80aa29f5482de2c27f58 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 07:28:53 +0800 Subject: [PATCH 20/55] [SPARK-20960][SQL] make ColumnVector public ## What changes were proposed in this pull request? move `ColumnVector` and related classes to `org.apache.spark.sql.vectorized`, and improve the document. ## How was this patch tested? existing tests. Author: Wenchen Fan Closes #20116 from cloud-fan/column-vector. (cherry picked from commit b297029130735316e1ac1144dee44761a12bfba7) Signed-off-by: gatorsmile --- .../VectorizedParquetRecordReader.java | 7 ++- .../vectorized/ColumnVectorUtils.java | 2 + .../vectorized/MutableColumnarRow.java | 4 ++ .../vectorized/WritableColumnVector.java | 7 ++- .../vectorized/ArrowColumnVector.java | 62 +------------------ .../vectorized/ColumnVector.java | 31 ++++++---- .../vectorized/ColumnarArray.java | 7 +-- .../vectorized/ColumnarBatch.java | 34 +++------- .../vectorized/ColumnarRow.java | 7 +-- .../sql/execution/ColumnarBatchScan.scala | 4 +- .../aggregate/HashAggregateExec.scala | 2 +- .../VectorizedHashMapGenerator.scala | 3 +- .../sql/execution/arrow/ArrowConverters.scala | 2 +- .../columnar/InMemoryTableScanExec.scala | 1 + .../execution/datasources/FileScanRDD.scala | 2 +- .../execution/python/ArrowPythonRunner.scala | 2 +- .../execution/arrow/ArrowWriterSuite.scala | 2 +- .../vectorized/ArrowColumnVectorSuite.scala | 1 + .../vectorized/ColumnVectorSuite.scala | 2 +- .../vectorized/ColumnarBatchSuite.scala | 6 +- 20 files changed, 63 insertions(+), 125 deletions(-) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ArrowColumnVector.java (94%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnVector.java (79%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarArray.java (95%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarBatch.java (73%) rename sql/core/src/main/java/org/apache/spark/sql/{execution => }/vectorized/ColumnarRow.java (96%) diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 6c157e85d411f..cd745b1f0e4e3 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -31,10 +31,10 @@ import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; -import org.apache.spark.sql.execution.vectorized.ColumnarBatch; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; @@ -248,7 +248,10 @@ public void enableReturningBatches() { * Advances to the next batch of rows. Returns false if there are no more. */ public boolean nextBatch() throws IOException { - columnarBatch.reset(); + for (WritableColumnVector vector : columnVectors) { + vector.reset(); + } + columnarBatch.setNumRows(0); if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java index bc62bc43484e5..b5cbe8e2839ba 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVectorUtils.java @@ -28,6 +28,8 @@ import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.util.DateTimeUtils; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java index 06602c147dfe9..70057a9def6c0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/MutableColumnarRow.java @@ -23,6 +23,10 @@ import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import org.apache.spark.sql.vectorized.ColumnarRow; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 5f6f125976e12..d2ae32b06f83b 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -23,6 +23,7 @@ import org.apache.spark.sql.internal.SQLConf; import org.apache.spark.sql.types.*; +import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.UTF8String; @@ -585,11 +586,11 @@ public final int appendArray(int length) { public final int appendStruct(boolean isNull) { if (isNull) { appendNull(); - for (ColumnVector c: childColumns) { + for (WritableColumnVector c: childColumns) { if (c.type instanceof StructType) { - ((WritableColumnVector) c).appendStruct(true); + c.appendStruct(true); } else { - ((WritableColumnVector) c).appendNull(); + c.appendNull(); } } } else { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java similarity index 94% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java index af5673e26a501..708333213f3f1 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ArrowColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ArrowColumnVector.java @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.*; @@ -34,11 +34,7 @@ public final class ArrowColumnVector extends ColumnVector { private ArrowColumnVector[] childColumns; private void ensureAccessible(int index) { - int valueCount = accessor.getValueCount(); - if (index < 0 || index >= valueCount) { - throw new IndexOutOfBoundsException( - String.format("index: %d, valueCount: %d", index, valueCount)); - } + ensureAccessible(index, 1); } private void ensureAccessible(int index, int count) { @@ -64,20 +60,12 @@ public void close() { accessor.close(); } - // - // APIs dealing with nulls - // - @Override public boolean isNullAt(int rowId) { ensureAccessible(rowId); return accessor.isNullAt(rowId); } - // - // APIs dealing with Booleans - // - @Override public boolean getBoolean(int rowId) { ensureAccessible(rowId); @@ -94,10 +82,6 @@ public boolean[] getBooleans(int rowId, int count) { return array; } - // - // APIs dealing with Bytes - // - @Override public byte getByte(int rowId) { ensureAccessible(rowId); @@ -114,10 +98,6 @@ public byte[] getBytes(int rowId, int count) { return array; } - // - // APIs dealing with Shorts - // - @Override public short getShort(int rowId) { ensureAccessible(rowId); @@ -134,10 +114,6 @@ public short[] getShorts(int rowId, int count) { return array; } - // - // APIs dealing with Ints - // - @Override public int getInt(int rowId) { ensureAccessible(rowId); @@ -154,10 +130,6 @@ public int[] getInts(int rowId, int count) { return array; } - // - // APIs dealing with Longs - // - @Override public long getLong(int rowId) { ensureAccessible(rowId); @@ -174,10 +146,6 @@ public long[] getLongs(int rowId, int count) { return array; } - // - // APIs dealing with floats - // - @Override public float getFloat(int rowId) { ensureAccessible(rowId); @@ -194,10 +162,6 @@ public float[] getFloats(int rowId, int count) { return array; } - // - // APIs dealing with doubles - // - @Override public double getDouble(int rowId) { ensureAccessible(rowId); @@ -214,10 +178,6 @@ public double[] getDoubles(int rowId, int count) { return array; } - // - // APIs dealing with Arrays - // - @Override public int getArrayLength(int rowId) { ensureAccessible(rowId); @@ -230,45 +190,27 @@ public int getArrayOffset(int rowId) { return accessor.getArrayOffset(rowId); } - // - // APIs dealing with Decimals - // - @Override public Decimal getDecimal(int rowId, int precision, int scale) { ensureAccessible(rowId); return accessor.getDecimal(rowId, precision, scale); } - // - // APIs dealing with UTF8Strings - // - @Override public UTF8String getUTF8String(int rowId) { ensureAccessible(rowId); return accessor.getUTF8String(rowId); } - // - // APIs dealing with Binaries - // - @Override public byte[] getBinary(int rowId) { ensureAccessible(rowId); return accessor.getBinary(rowId); } - /** - * Returns the data for the underlying array. - */ @Override public ArrowColumnVector arrayData() { return childColumns[0]; } - /** - * Returns the ordinal's child data column. - */ @Override public ArrowColumnVector getChildColumn(int ordinal) { return childColumns[ordinal]; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java similarity index 79% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java index dc7c1269bedd9..d1196e1299fee 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.util.MapData; import org.apache.spark.sql.types.DataType; @@ -22,24 +22,31 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * This class represents in-memory values of a column and provides the main APIs to access the data. - * It supports all the types and contains get APIs as well as their batched versions. The batched - * versions are considered to be faster and preferable whenever possible. + * An interface representing in-memory columnar data in Spark. This interface defines the main APIs + * to access the data, as well as their batched versions. The batched versions are considered to be + * faster and preferable whenever possible. * - * To handle nested schemas, ColumnVector has two types: Arrays and Structs. In both cases these - * columns have child columns. All of the data are stored in the child columns and the parent column - * only contains nullability. In the case of Arrays, the lengths and offsets are saved in the child - * column and are encoded identically to INTs. + * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values + * in this ColumnVector. * - * Maps are just a special case of a two field struct. + * ColumnVector supports all the data types including nested types. To handle nested types, + * ColumnVector can have children and is a tree structure. For struct type, it stores the actual + * data of each field in the corresponding child ColumnVector, and only stores null information in + * the parent ColumnVector. For array type, it stores the actual array elements in the child + * ColumnVector, and stores null information, array offsets and lengths in the parent ColumnVector. * - * Most of the APIs take the rowId as a parameter. This is the batch local 0-based row id for values - * in the current batch. + * ColumnVector is expected to be reused during the entire data loading process, to avoid allocating + * memory again and again. + * + * ColumnVector is meant to maximize CPU efficiency but not to minimize storage footprint. + * Implementations should prefer computing efficiency over storage efficiency when design the + * format. Since it is expected to reuse the ColumnVector instance while loading data, the storage + * footprint is negligible. */ public abstract class ColumnVector implements AutoCloseable { /** - * Returns the data type of this column. + * Returns the data type of this column vector. */ public final DataType dataType() { return type; } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java similarity index 95% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java index cbc39d1d0aec2..0d89a52e7a4fe 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarArray.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarArray.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.util.ArrayData; import org.apache.spark.sql.catalyst.util.MapData; @@ -23,8 +23,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * Array abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. + * Array abstraction in {@link ColumnVector}. */ public final class ColumnarArray extends ArrayData { // The data for this array. This array contains elements from @@ -33,7 +32,7 @@ public final class ColumnarArray extends ArrayData { private final int offset; private final int length; - ColumnarArray(ColumnVector data, int offset, int length) { + public ColumnarArray(ColumnVector data, int offset, int length) { this.data = data; this.offset = offset; this.length = length; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java similarity index 73% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java index a9d09aa679726..9ae1c6d9993f0 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarBatch.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarBatch.java @@ -14,26 +14,18 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import java.util.*; import org.apache.spark.sql.catalyst.InternalRow; +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow; import org.apache.spark.sql.types.StructType; /** - * This class is the in memory representation of rows as they are streamed through operators. It - * is designed to maximize CPU efficiency and not storage footprint. Since it is expected that - * each operator allocates one of these objects, the storage footprint on the task is negligible. - * - * The layout is a columnar with values encoded in their native format. Each RowBatch contains - * a horizontal partitioning of the data, split into columns. - * - * The ColumnarBatch supports either on heap or offheap modes with (mostly) the identical API. - * - * TODO: - * - There are many TODOs for the existing APIs. They should throw a not implemented exception. - * - Compaction: The batch and columns should be able to compact based on a selection vector. + * This class wraps multiple ColumnVectors as a row-wise table. It provides a row view of this + * batch so that Spark can access the data row by row. Instance of it is meant to be reused during + * the entire data loading process. */ public final class ColumnarBatch { public static final int DEFAULT_BATCH_SIZE = 4 * 1024; @@ -57,7 +49,7 @@ public void close() { } /** - * Returns an iterator over the rows in this batch. This skips rows that are filtered out. + * Returns an iterator over the rows in this batch. */ public Iterator rowIterator() { final int maxRows = numRows; @@ -87,19 +79,7 @@ public void remove() { } /** - * Resets the batch for writing. - */ - public void reset() { - for (int i = 0; i < numCols(); ++i) { - if (columns[i] instanceof WritableColumnVector) { - ((WritableColumnVector) columns[i]).reset(); - } - } - this.numRows = 0; - } - - /** - * Sets the number of rows that are valid. + * Sets the number of rows in this batch. */ public void setNumRows(int numRows) { assert(numRows <= this.capacity); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java similarity index 96% rename from sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java rename to sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java index 8bb33ed5b78c0..3c6656dec77cd 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/ColumnarRow.java +++ b/sql/core/src/main/java/org/apache/spark/sql/vectorized/ColumnarRow.java @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.sql.execution.vectorized; +package org.apache.spark.sql.vectorized; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.catalyst.expressions.GenericInternalRow; @@ -24,8 +24,7 @@ import org.apache.spark.unsafe.types.UTF8String; /** - * Row abstraction in {@link ColumnVector}. The instance of this class is intended - * to be reused, callers should copy the data out if it needs to be stored. + * Row abstraction in {@link ColumnVector}. */ public final class ColumnarRow extends InternalRow { // The data for this row. @@ -34,7 +33,7 @@ public final class ColumnarRow extends InternalRow { private final int rowId; private final int numFields; - ColumnarRow(ColumnVector data, int rowId) { + public ColumnarRow(ColumnVector data, int rowId) { assert (data.dataType() instanceof StructType); this.data = data; this.rowId = rowId; diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala index 782cec5e292ba..5617046e1396e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ColumnarBatchScan.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql.execution import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode} import org.apache.spark.sql.execution.metric.SQLMetrics -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.sql.types.DataType +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} /** * Helper trait for abstracting scan functionality using - * [[org.apache.spark.sql.execution.vectorized.ColumnarBatch]]es. + * [[ColumnarBatch]]es. */ private[sql] trait ColumnarBatchScan extends CodegenSupport { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 9a6f1c6dfa6a9..ce3c68810f3b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution._ import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} -import org.apache.spark.sql.execution.vectorized.{ColumnarRow, MutableColumnarRow} +import org.apache.spark.sql.execution.vectorized.MutableColumnarRow import org.apache.spark.sql.types.{DecimalType, StringType, StructType} import org.apache.spark.unsafe.KVIterator import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala index 0380ee8b09d63..0cf9b53ce1d5d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/VectorizedHashMapGenerator.scala @@ -20,8 +20,9 @@ package org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext -import org.apache.spark.sql.execution.vectorized.{ColumnarBatch, MutableColumnarRow, OnHeapColumnVector} +import org.apache.spark.sql.execution.vectorized.{MutableColumnarRow, OnHeapColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch /** * This is a helper class to generate an append-only vectorized hash map that can act as a 'cache' diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala index bcfc412430263..bcd1aa0890ba3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/arrow/ArrowConverters.scala @@ -32,8 +32,8 @@ import org.apache.spark.TaskContext import org.apache.spark.api.java.JavaRDD import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index 3e73393b12850..933b9753faa61 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partition import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec} import org.apache.spark.sql.execution.vectorized._ import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} case class InMemoryTableScanExec( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala index 8731ee88f87f2..835ce98462477 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileScanRDD.scala @@ -26,7 +26,7 @@ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.rdd.{InputFileBlockHolder, RDD} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.vectorized.ColumnarBatch +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util.NextIterator /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala index 5cc8ed3535654..dc5ba96e69aec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala @@ -30,8 +30,8 @@ import org.apache.spark._ import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.{ArrowUtils, ArrowWriter} -import org.apache.spark.sql.execution.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.util.Utils /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala index 508c116aae92e..c42bc60a59d67 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/arrow/ArrowWriterSuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.execution.arrow import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.util.ArrayData -import org.apache.spark.sql.execution.vectorized.ArrowColumnVector import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class ArrowWriterSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala index 03490ad15a655..7304803a092c0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ArrowColumnVectorSuite.scala @@ -23,6 +23,7 @@ import org.apache.arrow.vector.complex._ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ArrowColumnVector import org.apache.spark.unsafe.types.UTF8String class ArrowColumnVectorSuite extends SparkFunSuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 54b31cee031f6..944240f3bade5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -21,10 +21,10 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow -import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.execution.columnar.ColumnAccessor import org.apache.spark.sql.execution.columnar.compression.ColumnBuilderHelper import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarArray import org.apache.spark.unsafe.types.UTF8String class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 7848ebdcab6d0..675f06b31b970 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -33,6 +33,7 @@ import org.apache.spark.sql.{RandomDataGenerator, Row} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.execution.arrow.ArrowUtils import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnarBatch, ColumnVector} import org.apache.spark.unsafe.Platform import org.apache.spark.unsafe.types.CalendarInterval @@ -918,10 +919,7 @@ class ColumnarBatchSuite extends SparkFunSuite { assert(it.hasNext == false) // Reset and add 3 rows - batch.reset() - assert(batch.numRows() == 0) - assert(batch.rowIterator().hasNext == false) - + columns.foreach(_.reset()) // Add rows [NULL, 2.2, 2, "abc"], [3, NULL, 3, ""], [4, 4.4, 4, "world] columns(0).putNull(0) columns(1).putDouble(0, 2.2) From f51c8fde8bf08705bacf8a93b5dba685ebbcec17 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 13:14:52 +0800 Subject: [PATCH 21/55] [SPARK-22944][SQL] improve FoldablePropagation ## What changes were proposed in this pull request? `FoldablePropagation` is a little tricky as it needs to handle attributes that are miss-derived from children, e.g. outer join outputs. This rule does a kind of stop-able tree transform, to skip to apply this rule when hit a node which may have miss-derived attributes. Logically we should be able to apply this rule above the unsupported nodes, by just treating the unsupported nodes as leaf nodes. This PR improves this rule to not stop the tree transformation, but reduce the foldable expressions that we want to propagate. ## How was this patch tested? existing tests Author: Wenchen Fan Closes #20139 from cloud-fan/foldable. (cherry picked from commit 7d045c5f00e2c7c67011830e2169a4e130c3ace8) Signed-off-by: gatorsmile --- .../sql/catalyst/optimizer/expressions.scala | 65 +++++++++++-------- .../optimizer/FoldablePropagationSuite.scala | 23 ++++++- 2 files changed, 58 insertions(+), 30 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala index 7d830bbb7dc32..1c0b7bd806801 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/expressions.scala @@ -506,18 +506,21 @@ object NullPropagation extends Rule[LogicalPlan] { /** - * Propagate foldable expressions: * Replace attributes with aliases of the original foldable expressions if possible. - * Other optimizations will take advantage of the propagated foldable expressions. - * + * Other optimizations will take advantage of the propagated foldable expressions. For example, + * this rule can optimize * {{{ * SELECT 1.0 x, 'abc' y, Now() z ORDER BY x, y, 3 - * ==> SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() * }}} + * to + * {{{ + * SELECT 1.0 x, 'abc' y, Now() z ORDER BY 1.0, 'abc', Now() + * }}} + * and other rules can further optimize it and remove the ORDER BY operator. */ object FoldablePropagation extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = { - val foldableMap = AttributeMap(plan.flatMap { + var foldableMap = AttributeMap(plan.flatMap { case Project(projectList, _) => projectList.collect { case a: Alias if a.child.foldable => (a.toAttribute, a) } @@ -530,38 +533,44 @@ object FoldablePropagation extends Rule[LogicalPlan] { if (foldableMap.isEmpty) { plan } else { - var stop = false CleanupAliases(plan.transformUp { - // A leaf node should not stop the folding process (note that we are traversing up the - // tree, starting at the leaf nodes); so we are allowing it. - case l: LeafNode => - l - // We can only propagate foldables for a subset of unary nodes. - case u: UnaryNode if !stop && canPropagateFoldables(u) => + case u: UnaryNode if foldableMap.nonEmpty && canPropagateFoldables(u) => u.transformExpressions(replaceFoldable) - // Allow inner joins. We do not allow outer join, although its output attributes are - // derived from its children, they are actually different attributes: the output of outer - // join is not always picked from its children, but can also be null. + // Join derives the output attributes from its child while they are actually not the + // same attributes. For example, the output of outer join is not always picked from its + // children, but can also be null. We should exclude these miss-derived attributes when + // propagating the foldable expressions. // TODO(cloud-fan): It seems more reasonable to use new attributes as the output attributes // of outer join. - case j @ Join(_, _, Inner, _) if !stop => - j.transformExpressions(replaceFoldable) - - // We can fold the projections an expand holds. However expand changes the output columns - // and often reuses the underlying attributes; so we cannot assume that a column is still - // foldable after the expand has been applied. - // TODO(hvanhovell): Expand should use new attributes as the output attributes. - case expand: Expand if !stop => - val newExpand = expand.copy(projections = expand.projections.map { projection => + case j @ Join(left, right, joinType, _) if foldableMap.nonEmpty => + val newJoin = j.transformExpressions(replaceFoldable) + val missDerivedAttrsSet: AttributeSet = AttributeSet(joinType match { + case _: InnerLike | LeftExistence(_) => Nil + case LeftOuter => right.output + case RightOuter => left.output + case FullOuter => left.output ++ right.output + }) + foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => missDerivedAttrsSet.contains(attr) + }.toSeq) + newJoin + + // We can not replace the attributes in `Expand.output`. If there are other non-leaf + // operators that have the `output` field, we should put them here too. + case expand: Expand if foldableMap.nonEmpty => + expand.copy(projections = expand.projections.map { projection => projection.map(_.transform(replaceFoldable)) }) - stop = true - newExpand - case other => - stop = true + // For other plans, they are not safe to apply foldable propagation, and they should not + // propagate foldable expressions from children. + case other if foldableMap.nonEmpty => + val childrenOutputSet = AttributeSet(other.children.flatMap(_.output)) + foldableMap = AttributeMap(foldableMap.baseMap.values.filterNot { + case (attr, _) => childrenOutputSet.contains(attr) + }.toSeq) other }) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala index dccb32f0379a8..c28844642aed0 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FoldablePropagationSuite.scala @@ -147,8 +147,8 @@ class FoldablePropagationSuite extends PlanTest { test("Propagate in expand") { val c1 = Literal(1).as('a) val c2 = Literal(2).as('b) - val a1 = c1.toAttribute.withNullability(true) - val a2 = c2.toAttribute.withNullability(true) + val a1 = c1.toAttribute.newInstance().withNullability(true) + val a2 = c2.toAttribute.newInstance().withNullability(true) val expand = Expand( Seq(Seq(Literal(null), 'b), Seq('a, Literal(null))), Seq(a1, a2), @@ -161,4 +161,23 @@ class FoldablePropagationSuite extends PlanTest { val correctAnswer = correctExpand.where(a1.isNotNull).select(a1, a2).analyze comparePlans(optimized, correctAnswer) } + + test("Propagate above outer join") { + val left = LocalRelation('a.int).select('a, Literal(1).as('b)) + val right = LocalRelation('c.int).select('c, Literal(1).as('d)) + + val join = left.join( + right, + joinType = LeftOuter, + condition = Some('a === 'c && 'b === 'd)) + val query = join.select(('b + 3).as('res)).analyze + val optimized = Optimize.execute(query) + + val correctAnswer = left.join( + right, + joinType = LeftOuter, + condition = Some('a === 'c && Literal(1) === Literal(1))) + .select((Literal(1) + 3).as('res)).analyze + comparePlans(optimized, correctAnswer) + } } From 1860a43e9affb7619be0a5a1c786e264d09bc446 Mon Sep 17 00:00:00 2001 From: Felix Cheung Date: Wed, 3 Jan 2018 21:43:14 -0800 Subject: [PATCH 22/55] [SPARK-22933][SPARKR] R Structured Streaming API for withWatermark, trigger, partitionBy ## What changes were proposed in this pull request? R Structured Streaming API for withWatermark, trigger, partitionBy ## How was this patch tested? manual, unit tests Author: Felix Cheung Closes #20129 from felixcheung/rwater. (cherry picked from commit df95a908baf78800556636a76d58bba9b3dd943f) Signed-off-by: Felix Cheung --- R/pkg/NAMESPACE | 1 + R/pkg/R/DataFrame.R | 96 +++++++++++++++- R/pkg/R/SQLContext.R | 4 +- R/pkg/R/generics.R | 6 + R/pkg/tests/fulltests/test_streaming.R | 107 ++++++++++++++++++ python/pyspark/sql/streaming.py | 4 + .../sql/execution/streaming/Triggers.scala | 2 +- 7 files changed, 214 insertions(+), 6 deletions(-) diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 3219c6f0cc47b..c51eb0f39c4b1 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -179,6 +179,7 @@ exportMethods("arrange", "with", "withColumn", "withColumnRenamed", + "withWatermark", "write.df", "write.jdbc", "write.json", diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index fe238f6dd4eb0..9956f7eda91e6 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -3661,7 +3661,8 @@ setMethod("getNumPartitions", #' isStreaming #' #' Returns TRUE if this SparkDataFrame contains one or more sources that continuously return data -#' as it arrives. +#' as it arrives. A dataset that reads data from a streaming source must be executed as a +#' \code{StreamingQuery} using \code{write.stream}. #' #' @param x A SparkDataFrame #' @return TRUE if this SparkDataFrame is from a streaming source @@ -3707,7 +3708,17 @@ setMethod("isStreaming", #' @param df a streaming SparkDataFrame. #' @param source a name for external data source. #' @param outputMode one of 'append', 'complete', 'update'. -#' @param ... additional argument(s) passed to the method. +#' @param partitionBy a name or a list of names of columns to partition the output by on the file +#' system. If specified, the output is laid out on the file system similar to Hive's +#' partitioning scheme. +#' @param trigger.processingTime a processing time interval as a string, e.g. '5 seconds', +#' '1 minute'. This is a trigger that runs a query periodically based on the processing +#' time. If value is '0 seconds', the query will run as fast as possible, this is the +#' default. Only one trigger can be set. +#' @param trigger.once a logical, must be set to \code{TRUE}. This is a trigger that processes only +#' one batch of data in a streaming query then terminates the query. Only one trigger can be +#' set. +#' @param ... additional external data source specific named options. #' #' @family SparkDataFrame functions #' @seealso \link{read.stream} @@ -3725,7 +3736,8 @@ setMethod("isStreaming", #' # console #' q <- write.stream(wordCounts, "console", outputMode = "complete") #' # text stream -#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp") +#' q <- write.stream(df, "text", path = "/home/user/out", checkpointLocation = "/home/user/cp" +#' partitionBy = c("year", "month"), trigger.processingTime = "30 seconds") #' # memory stream #' q <- write.stream(wordCounts, "memory", queryName = "outs", outputMode = "complete") #' head(sql("SELECT * from outs")) @@ -3737,7 +3749,8 @@ setMethod("isStreaming", #' @note experimental setMethod("write.stream", signature(df = "SparkDataFrame"), - function(df, source = NULL, outputMode = NULL, ...) { + function(df, source = NULL, outputMode = NULL, partitionBy = NULL, + trigger.processingTime = NULL, trigger.once = NULL, ...) { if (!is.null(source) && !is.character(source)) { stop("source should be character, NULL or omitted. It is the data source specified ", "in 'spark.sql.sources.default' configuration by default.") @@ -3748,12 +3761,43 @@ setMethod("write.stream", if (is.null(source)) { source <- getDefaultSqlSource() } + cols <- NULL + if (!is.null(partitionBy)) { + if (!all(sapply(partitionBy, function(c) { is.character(c) }))) { + stop("All partitionBy column names should be characters.") + } + cols <- as.list(partitionBy) + } + jtrigger <- NULL + if (!is.null(trigger.processingTime) && !is.na(trigger.processingTime)) { + if (!is.null(trigger.once)) { + stop("Multiple triggers not allowed.") + } + interval <- as.character(trigger.processingTime) + if (nchar(interval) == 0) { + stop("Value for trigger.processingTime must be a non-empty string.") + } + jtrigger <- handledCallJStatic("org.apache.spark.sql.streaming.Trigger", + "ProcessingTime", + interval) + } else if (!is.null(trigger.once) && !is.na(trigger.once)) { + if (!is.logical(trigger.once) || !trigger.once) { + stop("Value for trigger.once must be TRUE.") + } + jtrigger <- callJStatic("org.apache.spark.sql.streaming.Trigger", "Once") + } options <- varargsToStrEnv(...) write <- handledCallJMethod(df@sdf, "writeStream") write <- callJMethod(write, "format", source) if (!is.null(outputMode)) { write <- callJMethod(write, "outputMode", outputMode) } + if (!is.null(cols)) { + write <- callJMethod(write, "partitionBy", cols) + } + if (!is.null(jtrigger)) { + write <- callJMethod(write, "trigger", jtrigger) + } write <- callJMethod(write, "options", options) ssq <- handledCallJMethod(write, "start") streamingQuery(ssq) @@ -3967,3 +4011,47 @@ setMethod("broadcast", sdf <- callJStatic("org.apache.spark.sql.functions", "broadcast", x@sdf) dataFrame(sdf) }) + +#' withWatermark +#' +#' Defines an event time watermark for this streaming SparkDataFrame. A watermark tracks a point in +#' time before which we assume no more late data is going to arrive. +#' +#' Spark will use this watermark for several purposes: +#' \itemize{ +#' \item{-} To know when a given time window aggregation can be finalized and thus can be emitted +#' when using output modes that do not allow updates. +#' \item{-} To minimize the amount of state that we need to keep for on-going aggregations. +#' } +#' The current watermark is computed by looking at the \code{MAX(eventTime)} seen across +#' all of the partitions in the query minus a user specified \code{delayThreshold}. Due to the cost +#' of coordinating this value across partitions, the actual watermark used is only guaranteed +#' to be at least \code{delayThreshold} behind the actual event time. In some cases we may still +#' process records that arrive more than \code{delayThreshold} late. +#' +#' @param x a streaming SparkDataFrame +#' @param eventTime a string specifying the name of the Column that contains the event time of the +#' row. +#' @param delayThreshold a string specifying the minimum delay to wait to data to arrive late, +#' relative to the latest record that has been processed in the form of an +#' interval (e.g. "1 minute" or "5 hours"). NOTE: This should not be negative. +#' @return a SparkDataFrame. +#' @aliases withWatermark,SparkDataFrame,character,character-method +#' @family SparkDataFrame functions +#' @rdname withWatermark +#' @name withWatermark +#' @export +#' @examples +#' \dontrun{ +#' sparkR.session() +#' schema <- structType(structField("time", "timestamp"), structField("value", "double")) +#' df <- read.stream("json", path = jsonDir, schema = schema, maxFilesPerTrigger = 1) +#' df <- withWatermark(df, "time", "10 minutes") +#' } +#' @note withWatermark since 2.3.0 +setMethod("withWatermark", + signature(x = "SparkDataFrame", eventTime = "character", delayThreshold = "character"), + function(x, eventTime, delayThreshold) { + sdf <- callJMethod(x@sdf, "withWatermark", eventTime, delayThreshold) + dataFrame(sdf) + }) diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R index 3b7f71bbbffb8..9d0a2d5e074e4 100644 --- a/R/pkg/R/SQLContext.R +++ b/R/pkg/R/SQLContext.R @@ -727,7 +727,9 @@ read.jdbc <- function(url, tableName, #' @param schema The data schema defined in structType or a DDL-formatted string, this is #' required for file-based streaming data source #' @param ... additional external data source specific named options, for instance \code{path} for -#' file-based streaming data source +#' file-based streaming data source. \code{timeZone} to indicate a timezone to be used to +#' parse timestamps in the JSON/CSV data sources or partition values; If it isn't set, it +#' uses the default value, session local timezone. #' @return SparkDataFrame #' @rdname read.stream #' @name read.stream diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index 5369c32544e5e..e0dde3339fabc 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -799,6 +799,12 @@ setGeneric("withColumn", function(x, colName, col) { standardGeneric("withColumn setGeneric("withColumnRenamed", function(x, existingCol, newCol) { standardGeneric("withColumnRenamed") }) +#' @rdname withWatermark +#' @export +setGeneric("withWatermark", function(x, eventTime, delayThreshold) { + standardGeneric("withWatermark") +}) + #' @rdname write.df #' @export setGeneric("write.df", function(df, path = NULL, ...) { standardGeneric("write.df") }) diff --git a/R/pkg/tests/fulltests/test_streaming.R b/R/pkg/tests/fulltests/test_streaming.R index 54f40bbd5f517..a354d50c6b54e 100644 --- a/R/pkg/tests/fulltests/test_streaming.R +++ b/R/pkg/tests/fulltests/test_streaming.R @@ -172,6 +172,113 @@ test_that("Terminated by error", { stopQuery(q) }) +test_that("PartitionBy", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + checkpointPath <- tempfile(pattern = "sparkr-test", fileext = ".checkpoint") + textPath <- tempfile(pattern = "sparkr-test", fileext = ".text") + df <- read.df(jsonPath, "json", stringSchema) + write.df(df, parquetPath, "parquet", "overwrite") + + df <- read.stream(path = parquetPath, schema = stringSchema) + + expect_error(write.stream(df, "json", path = textPath, checkpointLocation = "append", + partitionBy = c(1, 2)), + "All partitionBy column names should be characters") + + q <- write.stream(df, "json", path = textPath, checkpointLocation = "append", + partitionBy = "name") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + dirs <- list.files(textPath) + expect_equal(length(dirs[substring(dirs, 1, nchar("name=")) == "name="]), 3) + + unlink(checkpointPath) + unlink(textPath) + unlink(parquetPath) +}) + +test_that("Watermark", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + schema <- structType(structField("value", "string")) + t <- Sys.time() + df <- as.DataFrame(lapply(list(t), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + df <- read.stream(path = parquetPath, schema = "value STRING") + df <- withColumn(df, "eventTime", cast(df$value, "timestamp")) + df <- withWatermark(df, "eventTime", "10 seconds") + counts <- count(group_by(df, "eventTime")) + q <- write.stream(counts, "memory", queryName = "times", outputMode = "append") + + # first events + df <- as.DataFrame(lapply(list(t + 1, t, t + 2), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # advance watermark to 15 + df <- as.DataFrame(lapply(list(t + 25), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # old events, should be dropped + df <- as.DataFrame(lapply(list(t), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + # evict events less than previous watermark + df <- as.DataFrame(lapply(list(t + 25), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + times <- collect(sql("SELECT * FROM times")) + # looks like write timing can affect the first bucket; but it should be t + expect_equal(times[order(times$eventTime),][1, 2], 2) + + stopQuery(q) + unlink(parquetPath) +}) + +test_that("Trigger", { + parquetPath <- tempfile(pattern = "sparkr-test", fileext = ".parquet") + schema <- structType(structField("value", "string")) + df <- as.DataFrame(lapply(list(Sys.time()), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + df <- read.stream(path = parquetPath, schema = "value STRING") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = "", trigger.once = ""), "Multiple triggers not allowed.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = ""), + "Value for trigger.processingTime must be a non-empty string.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.processingTime = "invalid"), "illegal argument") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.once = ""), "Value for trigger.once must be TRUE.") + + expect_error(write.stream(df, "memory", queryName = "times", outputMode = "append", + trigger.once = FALSE), "Value for trigger.once must be TRUE.") + + q <- write.stream(df, "memory", queryName = "times", outputMode = "append", trigger.once = TRUE) + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + df <- as.DataFrame(lapply(list(Sys.time()), as.character), schema) + write.df(df, parquetPath, "parquet", "append") + awaitTermination(q, 5 * 1000) + callJMethod(q@ssq, "processAllAvailable") + + expect_equal(nrow(collect(sql("SELECT * FROM times"))), 1) + + stopQuery(q) + unlink(parquetPath) +}) + unlink(jsonPath) unlink(jsonPathNa) diff --git a/python/pyspark/sql/streaming.py b/python/pyspark/sql/streaming.py index fb228f99ba7ab..24ae3776a217b 100644 --- a/python/pyspark/sql/streaming.py +++ b/python/pyspark/sql/streaming.py @@ -793,6 +793,10 @@ def trigger(self, processingTime=None, once=None): .. note:: Evolving. :param processingTime: a processing time interval as a string, e.g. '5 seconds', '1 minute'. + Set a trigger that runs a query periodically based on the processing + time. Only one trigger can be set. + :param once: if set to True, set a trigger that processes only one batch of data in a + streaming query then terminates the query. Only one trigger can be set. >>> # trigger the query for execution every 5 seconds >>> writer = sdf.writeStream.trigger(processingTime='5 seconds') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala index 271bc4da99c08..19e3e55cb2829 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/Triggers.scala @@ -21,7 +21,7 @@ import org.apache.spark.annotation.{Experimental, InterfaceStability} import org.apache.spark.sql.streaming.Trigger /** - * A [[Trigger]] that process only one batch of data in a streaming query then terminates + * A [[Trigger]] that processes only one batch of data in a streaming query then terminates * the query. */ @Experimental From a7cfd6beaf35f79a744047a4a09714ef1da60293 Mon Sep 17 00:00:00 2001 From: Kent Yao Date: Thu, 4 Jan 2018 19:10:10 +0800 Subject: [PATCH 23/55] [SPARK-22950][SQL] Handle ChildFirstURLClassLoader's parent ## What changes were proposed in this pull request? ChildFirstClassLoader's parent is set to null, so we can't get jars from its parent. This will cause ClassNotFoundException during HiveClient initialization with builtin hive jars, where we may should use spark context loader instead. ## How was this patch tested? add new ut cc cloud-fan gatorsmile Author: Kent Yao Closes #20145 from yaooqinn/SPARK-22950. (cherry picked from commit 9fa703e89318922393bae03c0db4575f4f4b4c56) Signed-off-by: Wenchen Fan --- .../org/apache/spark/sql/hive/HiveUtils.scala | 4 +++- .../spark/sql/hive/HiveUtilsSuite.scala | 20 +++++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala index c489690af8cd1..c7717d70c996f 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveUtils.scala @@ -47,7 +47,7 @@ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf._ import org.apache.spark.sql.internal.StaticSQLConf.{CATALOG_IMPLEMENTATION, WAREHOUSE_PATH} import org.apache.spark.sql.types._ -import org.apache.spark.util.Utils +import org.apache.spark.util.{ChildFirstURLClassLoader, Utils} private[spark] object HiveUtils extends Logging { @@ -312,6 +312,8 @@ private[spark] object HiveUtils extends Logging { // starting from the given classLoader. def allJars(classLoader: ClassLoader): Array[URL] = classLoader match { case null => Array.empty[URL] + case childFirst: ChildFirstURLClassLoader => + childFirst.getURLs() ++ allJars(Utils.getSparkClassLoader) case urlClassLoader: URLClassLoader => urlClassLoader.getURLs ++ allJars(urlClassLoader.getParent) case other => allJars(other.getParent) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala index fdbfcf1a68440..8697d47e89e89 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveUtilsSuite.scala @@ -17,11 +17,16 @@ package org.apache.spark.sql.hive +import java.net.URL + import org.apache.hadoop.hive.conf.HiveConf.ConfVars +import org.apache.spark.SparkConf +import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.sql.QueryTest import org.apache.spark.sql.hive.test.TestHiveSingleton import org.apache.spark.sql.test.SQLTestUtils +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader} class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton { @@ -42,4 +47,19 @@ class HiveUtilsSuite extends QueryTest with SQLTestUtils with TestHiveSingleton assert(hiveConf("foo") === "bar") } } + + test("ChildFirstURLClassLoader's parent is null, get spark classloader instead") { + val conf = new SparkConf + val contextClassLoader = Thread.currentThread().getContextClassLoader + val loader = new ChildFirstURLClassLoader(Array(), contextClassLoader) + try { + Thread.currentThread().setContextClassLoader(loader) + HiveUtils.newClientForMetadata( + conf, + SparkHadoopUtil.newConfiguration(conf), + HiveUtils.newTemporaryConfiguration(useInMemoryDerby = true)) + } finally { + Thread.currentThread().setContextClassLoader(contextClassLoader) + } + } } From eb99b8adecc050240ce9d5e0b92a20f018df465e Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Thu, 4 Jan 2018 19:17:22 +0800 Subject: [PATCH 24/55] [SPARK-22945][SQL] add java UDF APIs in the functions object ## What changes were proposed in this pull request? Currently Scala users can use UDF like ``` val foo = udf((i: Int) => Math.random() + i).asNondeterministic df.select(foo('a)) ``` Python users can also do it with similar APIs. However Java users can't do it, we should add Java UDF APIs in the functions object. ## How was this patch tested? new tests Author: Wenchen Fan Closes #20141 from cloud-fan/udf. (cherry picked from commit d5861aba9d80ca15ad3f22793b79822e470d6913) Signed-off-by: Wenchen Fan --- .../apache/spark/sql/UDFRegistration.scala | 90 ++--- .../sql/expressions/UserDefinedFunction.scala | 1 + .../org/apache/spark/sql/functions.scala | 313 ++++++++++++++---- .../apache/spark/sql/JavaDataFrameSuite.java | 11 + .../scala/org/apache/spark/sql/UDFSuite.scala | 12 +- 5 files changed, 315 insertions(+), 112 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala index dc2468a721e41..f94baef39dfad 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/UDFRegistration.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql -import java.lang.reflect.{ParameterizedType, Type} +import java.lang.reflect.ParameterizedType import scala.reflect.runtime.universe.TypeTag import scala.util.Try @@ -110,29 +110,29 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends /* register 0-22 were generated by this script - (0 to 22).map { x => + (0 to 22).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) - val typeTags = (1 to x).map(i => s"A${i}: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) + val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor[A$i].dataType :: $s"}) println(s""" - /** - * Registers a deterministic Scala closure of ${x} arguments as user-defined function (UDF). - * @tparam RT return type of UDF. - * @since 1.3.0 - */ - def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try($inputTypes).toOption - def builder(e: Seq[Expression]) = if (e.length == $x) { - ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) - } else { - throw new AnalysisException("Invalid number of arguments for function " + name + - ". Expected: $x; Found: " + e.length) - } - functionRegistry.createOrReplaceTempFunction(name, builder) - val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) - if (nullable) udf else udf.asNonNullable() - }""") + |/** + | * Registers a deterministic Scala closure of $x arguments as user-defined function (UDF). + | * @tparam RT return type of UDF. + | * @since 1.3.0 + | */ + |def register[$typeTags](name: String, func: Function$x[$types]): UserDefinedFunction = { + | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val inputTypes = Try($inputTypes).toOption + | def builder(e: Seq[Expression]) = if (e.length == $x) { + | ScalaUDF(func, dataType, e, inputTypes.getOrElse(Nil), Some(name), nullable, udfDeterministic = true) + | } else { + | throw new AnalysisException("Invalid number of arguments for function " + name + + | ". Expected: $x; Found: " + e.length) + | } + | functionRegistry.createOrReplaceTempFunction(name, builder) + | val udf = UserDefinedFunction(func, dataType, inputTypes).withName(name) + | if (nullable) udf else udf.asNonNullable() + |}""".stripMargin) } (0 to 22).foreach { i => @@ -144,7 +144,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends val funcCall = if (i == 0) "() => func" else "func" println(s""" |/** - | * Register a user-defined function with ${i} arguments. + | * Register a deterministic Java UDF$i instance as user-defined function (UDF). | * @since $version | */ |def register(name: String, f: UDF$i[$extTypeArgs], returnType: DataType): Unit = { @@ -689,7 +689,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 0 arguments. + * Register a deterministic Java UDF0 instance as user-defined function (UDF). * @since 2.3.0 */ def register(name: String, f: UDF0[_], returnType: DataType): Unit = { @@ -704,7 +704,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 1 arguments. + * Register a deterministic Java UDF1 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF1[_, _], returnType: DataType): Unit = { @@ -719,7 +719,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 2 arguments. + * Register a deterministic Java UDF2 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF2[_, _, _], returnType: DataType): Unit = { @@ -734,7 +734,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 3 arguments. + * Register a deterministic Java UDF3 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF3[_, _, _, _], returnType: DataType): Unit = { @@ -749,7 +749,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 4 arguments. + * Register a deterministic Java UDF4 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF4[_, _, _, _, _], returnType: DataType): Unit = { @@ -764,7 +764,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 5 arguments. + * Register a deterministic Java UDF5 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF5[_, _, _, _, _, _], returnType: DataType): Unit = { @@ -779,7 +779,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 6 arguments. + * Register a deterministic Java UDF6 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF6[_, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -794,7 +794,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 7 arguments. + * Register a deterministic Java UDF7 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -809,7 +809,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 8 arguments. + * Register a deterministic Java UDF8 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -824,7 +824,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 9 arguments. + * Register a deterministic Java UDF9 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -839,7 +839,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 10 arguments. + * Register a deterministic Java UDF10 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -854,7 +854,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 11 arguments. + * Register a deterministic Java UDF11 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF11[_, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -869,7 +869,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 12 arguments. + * Register a deterministic Java UDF12 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF12[_, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -884,7 +884,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 13 arguments. + * Register a deterministic Java UDF13 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF13[_, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -899,7 +899,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 14 arguments. + * Register a deterministic Java UDF14 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF14[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -914,7 +914,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 15 arguments. + * Register a deterministic Java UDF15 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF15[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -929,7 +929,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 16 arguments. + * Register a deterministic Java UDF16 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF16[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -944,7 +944,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 17 arguments. + * Register a deterministic Java UDF17 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF17[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -959,7 +959,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 18 arguments. + * Register a deterministic Java UDF18 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF18[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -974,7 +974,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 19 arguments. + * Register a deterministic Java UDF19 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF19[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -989,7 +989,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 20 arguments. + * Register a deterministic Java UDF20 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF20[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -1004,7 +1004,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 21 arguments. + * Register a deterministic Java UDF21 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF21[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { @@ -1019,7 +1019,7 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends } /** - * Register a user-defined function with 22 arguments. + * Register a deterministic Java UDF22 instance as user-defined function (UDF). * @since 1.3.0 */ def register(name: String, f: UDF22[_, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _, _], returnType: DataType): Unit = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 03b654f830520..40a058d2cadd2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala @@ -66,6 +66,7 @@ case class UserDefinedFunction protected[sql] ( * * @since 1.3.0 */ + @scala.annotation.varargs def apply(exprs: Column*): Column = { Column(ScalaUDF( f, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 530a525a01dec..0d11682d80a3c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -24,6 +24,7 @@ import scala.util.Try import scala.util.control.NonFatal import org.apache.spark.annotation.InterfaceStability +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.ScalaReflection import org.apache.spark.sql.catalyst.analysis.{Star, UnresolvedFunction} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder @@ -32,7 +33,6 @@ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.catalyst.plans.logical.{HintInfo, ResolvedHint} import org.apache.spark.sql.execution.SparkSqlParser import org.apache.spark.sql.expressions.UserDefinedFunction -import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.util.Utils @@ -3254,42 +3254,66 @@ object functions { */ def map_values(e: Column): Column = withExpr { MapValues(e.expr) } - ////////////////////////////////////////////////////////////////////////////////////////////// - ////////////////////////////////////////////////////////////////////////////////////////////// - // scalastyle:off line.size.limit // scalastyle:off parameter.number /* Use the following code to generate: - (0 to 10).map { x => + + (0 to 10).foreach { x => val types = (1 to x).foldRight("RT")((i, s) => {s"A$i, $s"}) val typeTags = (1 to x).map(i => s"A$i: TypeTag").foldLeft("RT: TypeTag")(_ + ", " + _) val inputTypes = (1 to x).foldRight("Nil")((i, s) => {s"ScalaReflection.schemaFor(typeTag[A$i]).dataType :: $s"}) println(s""" - /** - * Defines a deterministic user-defined function of ${x} arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. - * - * @group udf_funcs - * @since 1.3.0 - */ - def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { - val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] - val inputTypes = Try($inputTypes).toOption - val udf = UserDefinedFunction(f, dataType, inputTypes) - if (nullable) udf else udf.asNonNullable() - }""") + |/** + | * Defines a Scala closure of $x arguments as user-defined function (UDF). + | * The data types are automatically inferred based on the Scala closure's + | * signature. By default the returned UDF is deterministic. To change it to + | * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + | * + | * @group udf_funcs + | * @since 1.3.0 + | */ + |def udf[$typeTags](f: Function$x[$types]): UserDefinedFunction = { + | val ScalaReflection.Schema(dataType, nullable) = ScalaReflection.schemaFor[RT] + | val inputTypes = Try($inputTypes).toOption + | val udf = UserDefinedFunction(f, dataType, inputTypes) + | if (nullable) udf else udf.asNonNullable() + |}""".stripMargin) + } + + (0 to 10).foreach { i => + val extTypeArgs = (0 to i).map(_ => "_").mkString(", ") + val anyTypeArgs = (0 to i).map(_ => "Any").mkString(", ") + val anyCast = s".asInstanceOf[UDF$i[$anyTypeArgs]]" + val anyParams = (1 to i).map(_ => "_: Any").mkString(", ") + val funcCall = if (i == 0) "() => func" else "func" + println(s""" + |/** + | * Defines a Java UDF$i instance as user-defined function (UDF). + | * The caller must specify the output data type, and there is no automatic input type coercion. + | * By default the returned UDF is deterministic. To change it to nondeterministic, call the + | * API `UserDefinedFunction.asNondeterministic()`. + | * + | * @group udf_funcs + | * @since 2.3.0 + | */ + |def udf(f: UDF$i[$extTypeArgs], returnType: DataType): UserDefinedFunction = { + | val func = f$anyCast.call($anyParams) + | UserDefinedFunction($funcCall, returnType, inputTypes = None) + |}""".stripMargin) } */ + ////////////////////////////////////////////////////////////////////////////////////////////// + // Scala UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + /** - * Defines a deterministic user-defined function of 0 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 0 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3302,10 +3326,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 1 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 1 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3318,10 +3342,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 2 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 2 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3334,10 +3358,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 3 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 3 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3350,10 +3374,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 4 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 4 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3366,10 +3390,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 5 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 5 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3382,10 +3406,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 6 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 6 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3398,10 +3422,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 7 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 7 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3414,10 +3438,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 8 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 8 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3430,10 +3454,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 9 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 9 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3446,10 +3470,10 @@ object functions { } /** - * Defines a deterministic user-defined function of 10 arguments as user-defined - * function (UDF). The data types are automatically inferred based on the function's - * signature. To change a UDF to nondeterministic, call the API - * `UserDefinedFunction.asNondeterministic()`. + * Defines a Scala closure of 10 arguments as user-defined function (UDF). + * The data types are automatically inferred based on the Scala closure's + * signature. By default the returned UDF is deterministic. To change it to + * nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. * * @group udf_funcs * @since 1.3.0 @@ -3461,13 +3485,172 @@ object functions { if (nullable) udf else udf.asNonNullable() } + ////////////////////////////////////////////////////////////////////////////////////////////// + // Java UDF functions + ////////////////////////////////////////////////////////////////////////////////////////////// + + /** + * Defines a Java UDF0 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF0[_], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF0[Any]].call() + UserDefinedFunction(() => func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF1 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF1[_, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF1[Any, Any]].call(_: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF2 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF2[_, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF2[Any, Any, Any]].call(_: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF3 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF3[_, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF3[Any, Any, Any, Any]].call(_: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF4 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF4[_, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF4[Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF5 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF5[_, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF5[Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF6 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF6[_, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF6[Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF7 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF7[_, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF7[Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF8 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF8[_, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF8[Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF9 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF9[_, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF9[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + + /** + * Defines a Java UDF10 instance as user-defined function (UDF). + * The caller must specify the output data type, and there is no automatic input type coercion. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. + * + * @group udf_funcs + * @since 2.3.0 + */ + def udf(f: UDF10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType): UserDefinedFunction = { + val func = f.asInstanceOf[UDF10[Any, Any, Any, Any, Any, Any, Any, Any, Any, Any, Any]].call(_: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any, _: Any) + UserDefinedFunction(func, returnType, inputTypes = None) + } + // scalastyle:on parameter.number // scalastyle:on line.size.limit /** * Defines a deterministic user-defined function (UDF) using a Scala closure. For this variant, * the caller must specify the output data type, and there is no automatic input type coercion. - * To change a UDF to nondeterministic, call the API `UserDefinedFunction.asNondeterministic()`. + * By default the returned UDF is deterministic. To change it to nondeterministic, call the + * API `UserDefinedFunction.asNondeterministic()`. * * @param f A closure in Scala * @param dataType The output data type of the UDF diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java index b007093dad84b..4f8a31f185724 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java @@ -36,6 +36,7 @@ import org.apache.spark.sql.Dataset; import org.apache.spark.sql.Row; import org.apache.spark.sql.RowFactory; +import org.apache.spark.sql.expressions.UserDefinedFunction; import org.apache.spark.sql.test.TestSparkSession; import org.apache.spark.sql.types.*; import org.apache.spark.util.sketch.BloomFilter; @@ -455,4 +456,14 @@ public void testCircularReferenceBean() { CircularReference1Bean bean = new CircularReference1Bean(); spark.createDataFrame(Arrays.asList(bean), CircularReference1Bean.class); } + + @Test + public void testUDF() { + UserDefinedFunction foo = udf((Integer i, String s) -> i.toString() + s, DataTypes.StringType); + Dataset df = spark.table("testData").select(foo.apply(col("key"), col("value"))); + String[] result = df.collectAsList().stream().map(row -> row.getString(0)).toArray(String[]::new); + String[] expected = spark.table("testData").collectAsList().stream() + .map(row -> row.get(0).toString() + row.getString(1)).toArray(String[]::new); + Assert.assertArrayEquals(expected, result); + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 7f1c009ca6e7a..db37be68e42e6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,12 +17,13 @@ package org.apache.spark.sql +import org.apache.spark.sql.api.java._ import org.apache.spark.sql.catalyst.plans.logical.Project import org.apache.spark.sql.execution.command.ExplainCommand -import org.apache.spark.sql.functions.{col, udf} +import org.apache.spark.sql.functions.udf import org.apache.spark.sql.test.SharedSQLContext import org.apache.spark.sql.test.SQLTestData._ -import org.apache.spark.sql.types.DataTypes +import org.apache.spark.sql.types.{DataTypes, DoubleType} private case class FunctionResult(f1: String, f2: String) @@ -128,6 +129,13 @@ class UDFSuite extends QueryTest with SharedSQLContext { val df2 = testData.select(bar()) assert(df2.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) assert(df2.head().getDouble(0) >= 0.0) + + val javaUdf = udf(new UDF0[Double] { + override def call(): Double = Math.random() + }, DoubleType).asNondeterministic() + val df3 = testData.select(javaUdf()) + assert(df3.logicalPlan.asInstanceOf[Project].projectList.forall(!_.deterministic)) + assert(df3.head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { From 1f5e3540c7535ceaea66ebd5ee2f598e8b3ba1a5 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Thu, 4 Jan 2018 21:07:31 +0800 Subject: [PATCH 25/55] [SPARK-22939][PYSPARK] Support Spark UDF in registerFunction ## What changes were proposed in this pull request? ```Python import random from pyspark.sql.functions import udf from pyspark.sql.types import IntegerType, StringType random_udf = udf(lambda: int(random.random() * 100), IntegerType()).asNondeterministic() spark.catalog.registerFunction("random_udf", random_udf, StringType()) spark.sql("SELECT random_udf()").collect() ``` We will get the following error. ``` Py4JError: An error occurred while calling o29.__getnewargs__. Trace: py4j.Py4JException: Method __getnewargs__([]) does not exist at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:318) at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:326) at py4j.Gateway.invoke(Gateway.java:274) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:214) at java.lang.Thread.run(Thread.java:745) ``` This PR is to support it. ## How was this patch tested? WIP Author: gatorsmile Closes #20137 from gatorsmile/registerFunction. (cherry picked from commit 5aadbc929cb194e06dbd3bab054a161569289af5) Signed-off-by: gatorsmile --- python/pyspark/sql/catalog.py | 27 +++++++++++++++---- python/pyspark/sql/context.py | 16 +++++++++--- python/pyspark/sql/tests.py | 49 +++++++++++++++++++++++++---------- python/pyspark/sql/udf.py | 21 ++++++++++----- 4 files changed, 84 insertions(+), 29 deletions(-) diff --git a/python/pyspark/sql/catalog.py b/python/pyspark/sql/catalog.py index 659bc65701a0c..156603128d063 100644 --- a/python/pyspark/sql/catalog.py +++ b/python/pyspark/sql/catalog.py @@ -227,15 +227,15 @@ def dropGlobalTempView(self, viewName): @ignore_unicode_prefix @since(2.0) def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. + """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` + as a UDF. The registered UDF can be used in SQL statement. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. :param name: name of the UDF - :param f: python function + :param f: a Python function, or a wrapped/native UserDefinedFunction :param returnType: a :class:`pyspark.sql.types.DataType` object :return: a wrapped :class:`UserDefinedFunction` @@ -255,9 +255,26 @@ def registerFunction(self, name, f, returnType=StringType()): >>> _ = spark.udf.register("stringLengthInt", len, IntegerType()) >>> spark.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType, StringType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> newRandom_udf = spark.catalog.registerFunction("random_udf", random_udf, StringType()) + >>> spark.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=u'82')] + >>> spark.range(1).select(newRandom_udf()).collect() # doctest: +SKIP + [Row(random_udf()=u'62')] """ - udf = UserDefinedFunction(f, returnType=returnType, name=name, - evalType=PythonEvalType.SQL_BATCHED_UDF) + + # This is to check whether the input function is a wrapped/native UserDefinedFunction + if hasattr(f, 'asNondeterministic'): + udf = UserDefinedFunction(f.func, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF, + deterministic=f.deterministic) + else: + udf = UserDefinedFunction(f, returnType=returnType, name=name, + evalType=PythonEvalType.SQL_BATCHED_UDF) self._jsparkSession.udf().registerPython(name, udf._judf) return udf._wrapped() diff --git a/python/pyspark/sql/context.py b/python/pyspark/sql/context.py index b1e723cdecef3..b8d86cc098e94 100644 --- a/python/pyspark/sql/context.py +++ b/python/pyspark/sql/context.py @@ -175,15 +175,15 @@ def range(self, start, end=None, step=1, numPartitions=None): @ignore_unicode_prefix @since(1.2) def registerFunction(self, name, f, returnType=StringType()): - """Registers a python function (including lambda function) as a UDF - so it can be used in SQL statements. + """Registers a Python function (including lambda function) or a :class:`UserDefinedFunction` + as a UDF. The registered UDF can be used in SQL statement. In addition to a name and the function itself, the return type can be optionally specified. When the return type is not given it default to a string and conversion will automatically be done. For any other return type, the produced object must match the specified type. :param name: name of the UDF - :param f: python function + :param f: a Python function, or a wrapped/native UserDefinedFunction :param returnType: a :class:`pyspark.sql.types.DataType` object :return: a wrapped :class:`UserDefinedFunction` @@ -203,6 +203,16 @@ def registerFunction(self, name, f, returnType=StringType()): >>> _ = sqlContext.udf.register("stringLengthInt", lambda x: len(x), IntegerType()) >>> sqlContext.sql("SELECT stringLengthInt('test')").collect() [Row(stringLengthInt(test)=4)] + + >>> import random + >>> from pyspark.sql.functions import udf + >>> from pyspark.sql.types import IntegerType, StringType + >>> random_udf = udf(lambda: random.randint(0, 100), IntegerType()).asNondeterministic() + >>> newRandom_udf = sqlContext.registerFunction("random_udf", random_udf, StringType()) + >>> sqlContext.sql("SELECT random_udf()").collect() # doctest: +SKIP + [Row(random_udf()=u'82')] + >>> sqlContext.range(1).select(newRandom_udf()).collect() # doctest: +SKIP + [Row(random_udf()=u'62')] """ return self.sparkSession.catalog.registerFunction(name, f, returnType) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 67bdb3d72d93b..6dc767f9ec46e 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -378,6 +378,41 @@ def test_udf2(self): [res] = self.spark.sql("SELECT strlen(a) FROM test WHERE strlen(a) > 1").collect() self.assertEqual(4, res[0]) + def test_udf3(self): + twoargs = self.spark.catalog.registerFunction( + "twoArgs", UserDefinedFunction(lambda x, y: len(x) + y), IntegerType()) + self.assertEqual(twoargs.deterministic, True) + [row] = self.spark.sql("SELECT twoArgs('test', 1)").collect() + self.assertEqual(row[0], 5) + + def test_nondeterministic_udf(self): + from pyspark.sql.functions import udf + import random + udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() + self.assertEqual(udf_random_col.deterministic, False) + df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) + udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) + [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() + self.assertEqual(row[0] + 10, row[1]) + + def test_nondeterministic_udf2(self): + import random + from pyspark.sql.functions import udf + random_udf = udf(lambda: random.randint(6, 6), IntegerType()).asNondeterministic() + self.assertEqual(random_udf.deterministic, False) + random_udf1 = self.spark.catalog.registerFunction("randInt", random_udf, StringType()) + self.assertEqual(random_udf1.deterministic, False) + [row] = self.spark.sql("SELECT randInt()").collect() + self.assertEqual(row[0], "6") + [row] = self.spark.range(1).select(random_udf1()).collect() + self.assertEqual(row[0], "6") + [row] = self.spark.range(1).select(random_udf()).collect() + self.assertEqual(row[0], 6) + # render_doc() reproduces the help() exception without printing output + pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) + pydoc.render_doc(random_udf) + pydoc.render_doc(random_udf1) + def test_chained_udf(self): self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) [row] = self.spark.sql("SELECT double(1)").collect() @@ -435,15 +470,6 @@ def test_udf_with_array_type(self): self.assertEqual(list(range(3)), l1) self.assertEqual(1, l2) - def test_nondeterministic_udf(self): - from pyspark.sql.functions import udf - import random - udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() - df = self.spark.createDataFrame([Row(1)]).select(udf_random_col().alias('RAND')) - udf_add_ten = udf(lambda rand: rand + 10, IntegerType()) - [row] = df.withColumn('RAND_PLUS_TEN', udf_add_ten('RAND')).collect() - self.assertEqual(row[0] + 10, row[1]) - def test_broadcast_in_udf(self): bar = {"a": "aa", "b": "bb", "c": "abc"} foo = self.sc.broadcast(bar) @@ -567,7 +593,6 @@ def test_read_multiple_orc_file(self): def test_udf_with_input_file_name(self): from pyspark.sql.functions import udf, input_file_name - from pyspark.sql.types import StringType sourceFile = udf(lambda path: path, StringType()) filePath = "python/test_support/sql/people1.json" row = self.spark.read.json(filePath).select(sourceFile(input_file_name())).first() @@ -575,7 +600,6 @@ def test_udf_with_input_file_name(self): def test_udf_with_input_file_name_for_hadooprdd(self): from pyspark.sql.functions import udf, input_file_name - from pyspark.sql.types import StringType def filename(path): return path @@ -635,7 +659,6 @@ def test_udf_with_string_return_type(self): def test_udf_shouldnt_accept_noncallable_object(self): from pyspark.sql.functions import UserDefinedFunction - from pyspark.sql.types import StringType non_callable = None self.assertRaises(TypeError, UserDefinedFunction, non_callable, StringType()) @@ -1299,7 +1322,6 @@ def test_between_function(self): df.filter(df.a.between(df.b, df.c)).collect()) def test_struct_type(self): - from pyspark.sql.types import StructType, StringType, StructField struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None) struct2 = StructType([StructField("f1", StringType(), True), StructField("f2", StringType(), True, None)]) @@ -1368,7 +1390,6 @@ def test_parse_datatype_string(self): _parse_datatype_string("a INT, c DOUBLE")) def test_metadata_null(self): - from pyspark.sql.types import StructType, StringType, StructField schema = StructType([StructField("f1", StringType(), True, None), StructField("f2", StringType(), True, {'a': None})]) rdd = self.sc.parallelize([["a", "b"], ["c", "d"]]) diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 54b5a8656e1c8..5e75eb6545333 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -56,7 +56,8 @@ def _create_udf(f, returnType, evalType): ) # Set the name of the UserDefinedFunction object to be the name of function f - udf_obj = UserDefinedFunction(f, returnType=returnType, name=None, evalType=evalType) + udf_obj = UserDefinedFunction( + f, returnType=returnType, name=None, evalType=evalType, deterministic=True) return udf_obj._wrapped() @@ -67,8 +68,10 @@ class UserDefinedFunction(object): .. versionadded:: 1.3 """ def __init__(self, func, - returnType=StringType(), name=None, - evalType=PythonEvalType.SQL_BATCHED_UDF): + returnType=StringType(), + name=None, + evalType=PythonEvalType.SQL_BATCHED_UDF, + deterministic=True): if not callable(func): raise TypeError( "Invalid function: not a function or callable (__call__ is not defined): " @@ -92,7 +95,7 @@ def __init__(self, func, func.__name__ if hasattr(func, '__name__') else func.__class__.__name__) self.evalType = evalType - self._deterministic = True + self.deterministic = deterministic @property def returnType(self): @@ -130,7 +133,7 @@ def _create_judf(self): wrapped_func = _wrap_function(sc, self.func, self.returnType) jdt = spark._jsparkSession.parseDataType(self.returnType.json()) judf = sc._jvm.org.apache.spark.sql.execution.python.UserDefinedPythonFunction( - self._name, wrapped_func, jdt, self.evalType, self._deterministic) + self._name, wrapped_func, jdt, self.evalType, self.deterministic) return judf def __call__(self, *cols): @@ -138,6 +141,9 @@ def __call__(self, *cols): sc = SparkContext._active_spark_context return Column(judf.apply(_to_seq(sc, cols, _to_java_column))) + # This function is for improving the online help system in the interactive interpreter. + # For example, the built-in help / pydoc.help. It wraps the UDF with the docstring and + # argument annotation. (See: SPARK-19161) def _wrapped(self): """ Wrap this udf with a function and attach docstring from func @@ -162,7 +168,8 @@ def wrapper(*args): wrapper.func = self.func wrapper.returnType = self.returnType wrapper.evalType = self.evalType - wrapper.asNondeterministic = self.asNondeterministic + wrapper.deterministic = self.deterministic + wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped() return wrapper @@ -172,5 +179,5 @@ def asNondeterministic(self): .. versionadded:: 2.3 """ - self._deterministic = False + self.deterministic = False return self From bcfeef5a944d56af1a5106f5c07296ea2c262991 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Thu, 4 Jan 2018 21:15:10 +0800 Subject: [PATCH 26/55] [SPARK-22771][SQL] Add a missing return statement in Concat.checkInputDataTypes ## What changes were proposed in this pull request? This pr is a follow-up to fix a bug left in #19977. ## How was this patch tested? Added tests in `StringExpressionsSuite`. Author: Takeshi Yamamuro Closes #20149 from maropu/SPARK-22771-FOLLOWUP. (cherry picked from commit 6f68316e98fad72b171df422566e1fc9a7bbfcde) Signed-off-by: gatorsmile --- .../sql/catalyst/expressions/stringExpressions.scala | 2 +- .../expressions/StringExpressionsSuite.scala | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index b0da55a4a961b..41dc762154a4c 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -58,7 +58,7 @@ case class Concat(children: Seq[Expression]) extends Expression { } else { val childTypes = children.map(_.dataType) if (childTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { - TypeCheckResult.TypeCheckFailure( + return TypeCheckResult.TypeCheckFailure( s"input to function $prettyName should have StringType or BinaryType, but it's " + childTypes.map(_.simpleString).mkString("[", ", ", "]")) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala index 54cde77176e27..97ddbeba2c5ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala @@ -51,6 +51,18 @@ class StringExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { checkEvaluation(Concat(strs.map(Literal.create(_, StringType))), strs.mkString, EmptyRow) } + test("SPARK-22771 Check Concat.checkInputDataTypes results") { + assert(Concat(Seq.empty[Expression]).checkInputDataTypes().isSuccess) + assert(Concat(Literal.create("a") :: Literal.create("b") :: Nil) + .checkInputDataTypes().isSuccess) + assert(Concat(Literal.create("a".getBytes) :: Literal.create("b".getBytes) :: Nil) + .checkInputDataTypes().isSuccess) + assert(Concat(Literal.create(1) :: Literal.create(2) :: Nil) + .checkInputDataTypes().isFailure) + assert(Concat(Literal.create("a") :: Literal.create("b".getBytes) :: Nil) + .checkInputDataTypes().isFailure) + } + test("concat_ws") { def testConcatWs(expected: String, sep: String, inputs: Any*): Unit = { val inputExprs = inputs.map { From cd92913f345c8d932d3c651626c7f803e6abdcdb Mon Sep 17 00:00:00 2001 From: jerryshao Date: Thu, 4 Jan 2018 11:39:42 -0800 Subject: [PATCH 27/55] [SPARK-21475][CORE][2ND ATTEMPT] Change to use NIO's Files API for external shuffle service ## What changes were proposed in this pull request? This PR is the second attempt of #18684 , NIO's Files API doesn't override `skip` method for `InputStream`, so it will bring in performance issue (mentioned in #20119). But using `FileInputStream`/`FileOutputStream` will also bring in memory issue (https://dzone.com/articles/fileinputstream-fileoutputstream-considered-harmful), which is severe for long running external shuffle service. So here in this proposal, only fixing the external shuffle service related code. ## How was this patch tested? Existing tests. Author: jerryshao Closes #20144 from jerryshao/SPARK-21475-v2. (cherry picked from commit 93f92c0ed7442a4382e97254307309977ff676f8) Signed-off-by: Shixiong Zhu --- .../apache/spark/network/buffer/FileSegmentManagedBuffer.java | 3 ++- .../apache/spark/network/shuffle/ShuffleIndexInformation.java | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java index c20fab83c3460..8b8f9892847c3 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java +++ b/common/network-common/src/main/java/org/apache/spark/network/buffer/FileSegmentManagedBuffer.java @@ -24,6 +24,7 @@ import java.io.RandomAccessFile; import java.nio.ByteBuffer; import java.nio.channels.FileChannel; +import java.nio.file.StandardOpenOption; import com.google.common.base.Objects; import com.google.common.io.ByteStreams; @@ -132,7 +133,7 @@ public Object convertToNetty() throws IOException { if (conf.lazyFileDescriptor()) { return new DefaultFileRegion(file, offset, length); } else { - FileChannel fileChannel = new FileInputStream(file).getChannel(); + FileChannel fileChannel = FileChannel.open(file.toPath(), StandardOpenOption.READ); return new DefaultFileRegion(fileChannel, offset, length); } } diff --git a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java index eacf485344b76..386738ece51a6 100644 --- a/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java +++ b/common/network-shuffle/src/main/java/org/apache/spark/network/shuffle/ShuffleIndexInformation.java @@ -19,10 +19,10 @@ import java.io.DataInputStream; import java.io.File; -import java.io.FileInputStream; import java.io.IOException; import java.nio.ByteBuffer; import java.nio.LongBuffer; +import java.nio.file.Files; /** * Keeps the index information for a particular map output @@ -39,7 +39,7 @@ public ShuffleIndexInformation(File indexFile) throws IOException { offsets = buffer.asLongBuffer(); DataInputStream dis = null; try { - dis = new DataInputStream(new FileInputStream(indexFile)); + dis = new DataInputStream(Files.newInputStream(indexFile.toPath())); dis.readFully(buffer.array()); } finally { if (dis != null) { From bc4bef472de0e99f74a80954d694c3d1744afe3a Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 16:19:00 -0600 Subject: [PATCH 28/55] [SPARK-22850][CORE] Ensure queued events are delivered to all event queues. The code in LiveListenerBus was queueing events before start in the queues themselves; so in situations like the following: bus.post(someEvent) bus.addToEventLogQueue(listener) bus.start() "someEvent" would not be delivered to "listener" if that was the first listener in the queue, because the queue wouldn't exist when the event was posted. This change buffers the events before starting the bus in the bus itself, so that they can be delivered to all registered queues when the bus is started. Also tweaked the unit tests to cover the behavior above. Author: Marcelo Vanzin Closes #20039 from vanzin/SPARK-22850. (cherry picked from commit d2cddc88eac32f26b18ec26bb59e85c6f09a8c88) Signed-off-by: Imran Rashid --- .../spark/scheduler/LiveListenerBus.scala | 45 ++++++++++++++++--- .../spark/scheduler/SparkListenerSuite.scala | 21 +++++---- 2 files changed, 52 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala index 23121402b1025..ba6387a8f08ad 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/LiveListenerBus.scala @@ -62,6 +62,9 @@ private[spark] class LiveListenerBus(conf: SparkConf) { private val queues = new CopyOnWriteArrayList[AsyncEventQueue]() + // Visible for testing. + @volatile private[scheduler] var queuedEvents = new mutable.ListBuffer[SparkListenerEvent]() + /** Add a listener to queue shared by all non-internal listeners. */ def addToSharedQueue(listener: SparkListenerInterface): Unit = { addToQueue(listener, SHARED_QUEUE) @@ -125,13 +128,39 @@ private[spark] class LiveListenerBus(conf: SparkConf) { /** Post an event to all queues. */ def post(event: SparkListenerEvent): Unit = { - if (!stopped.get()) { - metrics.numEventsPosted.inc() - val it = queues.iterator() - while (it.hasNext()) { - it.next().post(event) + if (stopped.get()) { + return + } + + metrics.numEventsPosted.inc() + + // If the event buffer is null, it means the bus has been started and we can avoid + // synchronization and post events directly to the queues. This should be the most + // common case during the life of the bus. + if (queuedEvents == null) { + postToQueues(event) + return + } + + // Otherwise, need to synchronize to check whether the bus is started, to make sure the thread + // calling start() picks up the new event. + synchronized { + if (!started.get()) { + queuedEvents += event + return } } + + // If the bus was already started when the check above was made, just post directly to the + // queues. + postToQueues(event) + } + + private def postToQueues(event: SparkListenerEvent): Unit = { + val it = queues.iterator() + while (it.hasNext()) { + it.next().post(event) + } } /** @@ -149,7 +178,11 @@ private[spark] class LiveListenerBus(conf: SparkConf) { } this.sparkContext = sc - queues.asScala.foreach(_.start(sc)) + queues.asScala.foreach { q => + q.start(sc) + queuedEvents.foreach(q.post) + } + queuedEvents = null metricsSystem.registerSource(metrics) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 1beb36afa95f0..da6ecb82c7e42 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -48,7 +48,7 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match bus.metrics.metricRegistry.counter(s"queue.$SHARED_QUEUE.numDroppedEvents").getCount } - private def queueSize(bus: LiveListenerBus): Int = { + private def sharedQueueSize(bus: LiveListenerBus): Int = { bus.metrics.metricRegistry.getGauges().get(s"queue.$SHARED_QUEUE.size").getValue() .asInstanceOf[Int] } @@ -73,12 +73,11 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match val conf = new SparkConf() val counter = new BasicJobCounter val bus = new LiveListenerBus(conf) - bus.addToSharedQueue(counter) // Metrics are initially empty. assert(bus.metrics.numEventsPosted.getCount === 0) assert(numDroppedEvents(bus) === 0) - assert(queueSize(bus) === 0) + assert(bus.queuedEvents.size === 0) assert(eventProcessingTimeCount(bus) === 0) // Post five events: @@ -87,7 +86,10 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Five messages should be marked as received and queued, but no messages should be posted to // listeners yet because the the listener bus hasn't been started. assert(bus.metrics.numEventsPosted.getCount === 5) - assert(queueSize(bus) === 5) + assert(bus.queuedEvents.size === 5) + + // Add the counter to the bus after messages have been queued for later delivery. + bus.addToSharedQueue(counter) assert(counter.count === 0) // Starting listener bus should flush all buffered events @@ -95,9 +97,12 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match Mockito.verify(mockMetricsSystem).registerSource(bus.metrics) bus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS) assert(counter.count === 5) - assert(queueSize(bus) === 0) + assert(sharedQueueSize(bus) === 0) assert(eventProcessingTimeCount(bus) === 5) + // After the bus is started, there should be no more queued events. + assert(bus.queuedEvents === null) + // After listener bus has stopped, posting events should not increment counter bus.stop() (1 to 5).foreach { _ => bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) } @@ -188,18 +193,18 @@ class SparkListenerSuite extends SparkFunSuite with LocalSparkContext with Match // Post a message to the listener bus and wait for processing to begin: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) listenerStarted.acquire() - assert(queueSize(bus) === 0) + assert(sharedQueueSize(bus) === 0) assert(numDroppedEvents(bus) === 0) // If we post an additional message then it should remain in the queue because the listener is // busy processing the first event: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(queueSize(bus) === 1) + assert(sharedQueueSize(bus) === 1) assert(numDroppedEvents(bus) === 0) // The queue is now full, so any additional events posted to the listener will be dropped: bus.post(SparkListenerJobEnd(0, jobCompletionTime, JobSucceeded)) - assert(queueSize(bus) === 1) + assert(sharedQueueSize(bus) === 1) assert(numDroppedEvents(bus) === 1) // Allow the the remaining events to be processed so we can stop the listener bus: From 2ab4012adda941ebd637bd248f65cefdf4aaf110 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 15:00:09 -0800 Subject: [PATCH 29/55] [SPARK-22948][K8S] Move SparkPodInitContainer to correct package. Author: Marcelo Vanzin Closes #20156 from vanzin/SPARK-22948. (cherry picked from commit 95f9659abe8845f9f3f42fd7ababd79e55c52489) Signed-off-by: Marcelo Vanzin --- dev/sparktestsupport/modules.py | 2 +- .../spark/deploy/{rest => }/k8s/SparkPodInitContainer.scala | 2 +- .../deploy/{rest => }/k8s/SparkPodInitContainerSuite.scala | 2 +- .../docker/src/main/dockerfiles/init-container/Dockerfile | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) rename resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/{rest => }/k8s/SparkPodInitContainer.scala (99%) rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/{rest => }/k8s/SparkPodInitContainerSuite.scala (98%) diff --git a/dev/sparktestsupport/modules.py b/dev/sparktestsupport/modules.py index f834563da9dda..7164180a6a7b0 100644 --- a/dev/sparktestsupport/modules.py +++ b/dev/sparktestsupport/modules.py @@ -539,7 +539,7 @@ def __hash__(self): kubernetes = Module( name="kubernetes", dependencies=[], - source_file_regexes=["resource-managers/kubernetes/core"], + source_file_regexes=["resource-managers/kubernetes"], build_profile_flags=["-Pkubernetes"], sbt_test_goals=["kubernetes/test"] ) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala similarity index 99% rename from resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala rename to resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala index 4a4b628aedbbf..c0f08786b76a1 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainer.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/SparkPodInitContainer.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.rest.k8s +package org.apache.spark.deploy.k8s import java.io.File import java.util.concurrent.TimeUnit diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala similarity index 98% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala index 6c557ec4a7c9a..e0f29ecd0fb53 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/rest/k8s/SparkPodInitContainerSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SparkPodInitContainerSuite.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.rest.k8s +package org.apache.spark.deploy.k8s import java.io.File import java.util.UUID diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index 055493188fcb7..047056ab2633b 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -21,4 +21,4 @@ FROM spark-base # command should be invoked from the top level directory of the Spark distribution. E.g.: # docker build -t spark-init:latest -f kubernetes/dockerfiles/init-container/Dockerfile . -ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.rest.k8s.SparkPodInitContainer" ] +ENTRYPOINT [ "/opt/entrypoint.sh", "/opt/spark/bin/spark-class", "org.apache.spark.deploy.k8s.SparkPodInitContainer" ] From 84707f0c6afa9c5417e271657ff930930f82213c Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 4 Jan 2018 15:35:20 -0800 Subject: [PATCH 30/55] [SPARK-22953][K8S] Avoids adding duplicated secret volumes when init-container is used ## What changes were proposed in this pull request? User-specified secrets are mounted into both the main container and init-container (when it is used) in a Spark driver/executor pod, using the `MountSecretsBootstrap`. Because `MountSecretsBootstrap` always adds new secret volumes for the secrets to the pod, the same secret volumes get added twice, one when mounting the secrets to the main container, and the other when mounting the secrets to the init-container. This PR fixes the issue by separating `MountSecretsBootstrap.mountSecrets` out into two methods: `addSecretVolumes` for adding secret volumes to a pod and `mountSecrets` for mounting secret volumes to a container, respectively. `addSecretVolumes` is only called once for each pod, whereas `mountSecrets` is called individually for the main container and the init-container (if it is used). Ref: https://github.com/apache-spark-on-k8s/spark/issues/594. ## How was this patch tested? Unit tested and manually tested. vanzin This replaces https://github.com/apache/spark/pull/20148. hex108 foxish kimoonkim Author: Yinan Li Closes #20159 from liyinan926/master. (cherry picked from commit e288fc87a027ec1e1a21401d1f151df20dbfecf3) Signed-off-by: Marcelo Vanzin --- .../deploy/k8s/MountSecretsBootstrap.scala | 30 ++++++++++++------- .../k8s/submit/DriverConfigOrchestrator.scala | 16 +++++----- .../steps/BasicDriverConfigurationStep.scala | 2 +- .../submit/steps/DriverMountSecretsStep.scala | 4 +-- .../InitContainerMountSecretsStep.scala | 11 +++---- .../cluster/k8s/ExecutorPodFactory.scala | 6 ++-- .../k8s/{submit => }/SecretVolumeUtils.scala | 18 +++++------ .../BasicDriverConfigurationStepSuite.scala | 4 +-- .../steps/DriverMountSecretsStepSuite.scala | 4 +-- .../InitContainerMountSecretsStepSuite.scala | 7 +---- .../cluster/k8s/ExecutorPodFactorySuite.scala | 14 +++++---- 11 files changed, 61 insertions(+), 55 deletions(-) rename resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/{submit => }/SecretVolumeUtils.scala (71%) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala index 8286546ce0641..c35e7db51d407 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/MountSecretsBootstrap.scala @@ -24,26 +24,36 @@ import io.fabric8.kubernetes.api.model.{Container, ContainerBuilder, Pod, PodBui private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, String]) { /** - * Mounts Kubernetes secrets as secret volumes into the given container in the given pod. + * Add new secret volumes for the secrets specified in secretNamesToMountPaths into the given pod. * * @param pod the pod into which the secret volumes are being added. - * @param container the container into which the secret volumes are being mounted. - * @return the updated pod and container with the secrets mounted. + * @return the updated pod with the secret volumes added. */ - def mountSecrets(pod: Pod, container: Container): (Pod, Container) = { + def addSecretVolumes(pod: Pod): Pod = { var podBuilder = new PodBuilder(pod) secretNamesToMountPaths.keys.foreach { name => podBuilder = podBuilder .editOrNewSpec() .addNewVolume() - .withName(secretVolumeName(name)) - .withNewSecret() - .withSecretName(name) - .endSecret() - .endVolume() + .withName(secretVolumeName(name)) + .withNewSecret() + .withSecretName(name) + .endSecret() + .endVolume() .endSpec() } + podBuilder.build() + } + + /** + * Mounts Kubernetes secret volumes of the secrets specified in secretNamesToMountPaths into the + * given container. + * + * @param container the container into which the secret volumes are being mounted. + * @return the updated container with the secrets mounted. + */ + def mountSecrets(container: Container): Container = { var containerBuilder = new ContainerBuilder(container) secretNamesToMountPaths.foreach { case (name, path) => containerBuilder = containerBuilder @@ -53,7 +63,7 @@ private[spark] class MountSecretsBootstrap(secretNamesToMountPaths: Map[String, .endVolumeMount() } - (podBuilder.build(), containerBuilder.build()) + containerBuilder.build() } private def secretVolumeName(secretName: String): String = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala index 00c9c4ee49177..c9cc300d65569 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/DriverConfigOrchestrator.scala @@ -127,6 +127,12 @@ private[spark] class DriverConfigOrchestrator( Nil } + val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { + Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) + } else { + Nil + } + val initContainerBootstrapStep = if (existNonContainerLocalFiles(sparkJars ++ sparkFiles)) { val orchestrator = new InitContainerConfigOrchestrator( sparkJars, @@ -147,19 +153,13 @@ private[spark] class DriverConfigOrchestrator( Nil } - val mountSecretsStep = if (secretNamesToMountPaths.nonEmpty) { - Seq(new DriverMountSecretsStep(new MountSecretsBootstrap(secretNamesToMountPaths))) - } else { - Nil - } - Seq( initialSubmissionStep, serviceBootstrapStep, kubernetesCredentialsStep) ++ dependencyResolutionStep ++ - initContainerBootstrapStep ++ - mountSecretsStep + mountSecretsStep ++ + initContainerBootstrapStep } private def existNonContainerLocalFiles(files: Seq[String]): Boolean = { diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala index b7a69a7dfd472..eca46b84c6066 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStep.scala @@ -119,7 +119,7 @@ private[spark] class BasicDriverConfigurationStep( .endEnv() .addNewEnv() .withName(ENV_DRIVER_ARGS) - .withValue(appArgs.map(arg => "\"" + arg + "\"").mkString(" ")) + .withValue(appArgs.mkString(" ")) .endEnv() .addNewEnv() .withName(ENV_DRIVER_BIND_ADDRESS) diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala index f872e0f4b65d1..91e9a9f211335 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStep.scala @@ -28,8 +28,8 @@ private[spark] class DriverMountSecretsStep( bootstrap: MountSecretsBootstrap) extends DriverConfigurationStep { override def configureDriver(driverSpec: KubernetesDriverSpec): KubernetesDriverSpec = { - val (pod, container) = bootstrap.mountSecrets( - driverSpec.driverPod, driverSpec.driverContainer) + val pod = bootstrap.addSecretVolumes(driverSpec.driverPod) + val container = bootstrap.mountSecrets(driverSpec.driverContainer) driverSpec.copy( driverPod = pod, driverContainer = container diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala index c0e7bb20cce8c..0daa7b95e8aae 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStep.scala @@ -28,12 +28,9 @@ private[spark] class InitContainerMountSecretsStep( bootstrap: MountSecretsBootstrap) extends InitContainerConfigurationStep { override def configureInitContainer(spec: InitContainerSpec) : InitContainerSpec = { - val (driverPod, initContainer) = bootstrap.mountSecrets( - spec.driverPod, - spec.initContainer) - spec.copy( - driverPod = driverPod, - initContainer = initContainer - ) + // Mount the secret volumes given that the volumes have already been added to the driver pod + // when mounting the secrets into the main driver container. + val initContainer = bootstrap.mountSecrets(spec.initContainer) + spec.copy(initContainer = initContainer) } } diff --git a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala index ba5d891f4c77e..066d7e9f70ca5 100644 --- a/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala +++ b/resource-managers/kubernetes/core/src/main/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactory.scala @@ -214,7 +214,7 @@ private[spark] class ExecutorPodFactory( val (maybeSecretsMountedPod, maybeSecretsMountedContainer) = mountSecretsBootstrap.map { bootstrap => - bootstrap.mountSecrets(executorPod, containerWithLimitCores) + (bootstrap.addSecretVolumes(executorPod), bootstrap.mountSecrets(containerWithLimitCores)) }.getOrElse((executorPod, containerWithLimitCores)) val (bootstrappedPod, bootstrappedContainer) = @@ -227,7 +227,9 @@ private[spark] class ExecutorPodFactory( val (pod, mayBeSecretsMountedInitContainer) = initContainerMountSecretsBootstrap.map { bootstrap => - bootstrap.mountSecrets(podWithInitContainer.pod, podWithInitContainer.initContainer) + // Mount the secret volumes given that the volumes have already been added to the + // executor pod when mounting the secrets into the main executor container. + (podWithInitContainer.pod, bootstrap.mountSecrets(podWithInitContainer.initContainer)) }.getOrElse((podWithInitContainer.pod, podWithInitContainer.initContainer)) val bootstrappedPod = KubernetesUtils.appendInitContainer( diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala similarity index 71% rename from resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala rename to resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala index 8388c16ded268..16780584a674a 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/SecretVolumeUtils.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/SecretVolumeUtils.scala @@ -14,7 +14,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -package org.apache.spark.deploy.k8s.submit +package org.apache.spark.deploy.k8s import scala.collection.JavaConverters._ @@ -22,15 +22,15 @@ import io.fabric8.kubernetes.api.model.{Container, Pod} private[spark] object SecretVolumeUtils { - def podHasVolume(driverPod: Pod, volumeName: String): Boolean = { - driverPod.getSpec.getVolumes.asScala.exists(volume => volume.getName == volumeName) + def podHasVolume(pod: Pod, volumeName: String): Boolean = { + pod.getSpec.getVolumes.asScala.exists { volume => + volume.getName == volumeName + } } - def containerHasVolume( - driverContainer: Container, - volumeName: String, - mountPath: String): Boolean = { - driverContainer.getVolumeMounts.asScala.exists(volumeMount => - volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath) + def containerHasVolume(container: Container, volumeName: String, mountPath: String): Boolean = { + container.getVolumeMounts.asScala.exists { volumeMount => + volumeMount.getName == volumeName && volumeMount.getMountPath == mountPath + } } } diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala index e864c6a16eeb1..8ee629ac8ddc1 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/BasicDriverConfigurationStepSuite.scala @@ -33,7 +33,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { private val CONTAINER_IMAGE_PULL_POLICY = "IfNotPresent" private val APP_NAME = "spark-test" private val MAIN_CLASS = "org.apache.spark.examples.SparkPi" - private val APP_ARGS = Array("arg1", "arg2", "arg 3") + private val APP_ARGS = Array("arg1", "arg2", "\"arg 3\"") private val CUSTOM_ANNOTATION_KEY = "customAnnotation" private val CUSTOM_ANNOTATION_VALUE = "customAnnotationValue" private val DRIVER_CUSTOM_ENV_KEY1 = "customDriverEnv1" @@ -82,7 +82,7 @@ class BasicDriverConfigurationStepSuite extends SparkFunSuite { assert(envs(ENV_SUBMIT_EXTRA_CLASSPATH) === "/opt/spark/spark-examples.jar") assert(envs(ENV_DRIVER_MEMORY) === "256M") assert(envs(ENV_DRIVER_MAIN_CLASS) === MAIN_CLASS) - assert(envs(ENV_DRIVER_ARGS) === "\"arg1\" \"arg2\" \"arg 3\"") + assert(envs(ENV_DRIVER_ARGS) === "arg1 arg2 \"arg 3\"") assert(envs(DRIVER_CUSTOM_ENV_KEY1) === "customDriverEnv1") assert(envs(DRIVER_CUSTOM_ENV_KEY2) === "customDriverEnv2") diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala index 9ec0cb55de5aa..960d0bda1d011 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/DriverMountSecretsStepSuite.scala @@ -17,8 +17,8 @@ package org.apache.spark.deploy.k8s.submit.steps import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.{KubernetesDriverSpec, SecretVolumeUtils} +import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} +import org.apache.spark.deploy.k8s.submit.KubernetesDriverSpec class DriverMountSecretsStepSuite extends SparkFunSuite { diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala index eab4e17659456..7ac0bde80dfe6 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/deploy/k8s/submit/steps/initcontainer/InitContainerMountSecretsStepSuite.scala @@ -19,8 +19,7 @@ package org.apache.spark.deploy.k8s.submit.steps.initcontainer import io.fabric8.kubernetes.api.model.{ContainerBuilder, PodBuilder} import org.apache.spark.SparkFunSuite -import org.apache.spark.deploy.k8s.MountSecretsBootstrap -import org.apache.spark.deploy.k8s.submit.SecretVolumeUtils +import org.apache.spark.deploy.k8s.{MountSecretsBootstrap, SecretVolumeUtils} class InitContainerMountSecretsStepSuite extends SparkFunSuite { @@ -44,12 +43,8 @@ class InitContainerMountSecretsStepSuite extends SparkFunSuite { val initContainerMountSecretsStep = new InitContainerMountSecretsStep(mountSecretsBootstrap) val configuredInitContainerSpec = initContainerMountSecretsStep.configureInitContainer( baseInitContainerSpec) - - val podWithSecretsMounted = configuredInitContainerSpec.driverPod val initContainerWithSecretsMounted = configuredInitContainerSpec.initContainer - Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => - assert(SecretVolumeUtils.podHasVolume(podWithSecretsMounted, volumeName))) Seq(s"$SECRET_FOO-volume", s"$SECRET_BAR-volume").foreach(volumeName => assert(SecretVolumeUtils.containerHasVolume( initContainerWithSecretsMounted, volumeName, SECRET_MOUNT_PATH))) diff --git a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala index 7121a802c69c1..884da8aabd880 100644 --- a/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala +++ b/resource-managers/kubernetes/core/src/test/scala/org/apache/spark/scheduler/cluster/k8s/ExecutorPodFactorySuite.scala @@ -25,7 +25,7 @@ import org.mockito.Mockito._ import org.scalatest.{BeforeAndAfter, BeforeAndAfterEach} import org.apache.spark.{SparkConf, SparkFunSuite} -import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer} +import org.apache.spark.deploy.k8s.{InitContainerBootstrap, MountSecretsBootstrap, PodWithDetachedInitContainer, SecretVolumeUtils} import org.apache.spark.deploy.k8s.Config._ import org.apache.spark.deploy.k8s.Constants._ @@ -165,17 +165,19 @@ class ExecutorPodFactorySuite extends SparkFunSuite with BeforeAndAfter with Bef val factory = new ExecutorPodFactory( conf, - None, + Some(secretsBootstrap), Some(initContainerBootstrap), Some(secretsBootstrap)) val executor = factory.createExecutorPod( "1", "dummy", "dummy", Seq[(String, String)](), driverPod, Map[String, Int]()) + assert(executor.getSpec.getVolumes.size() === 1) + assert(SecretVolumeUtils.podHasVolume(executor, "secret1-volume")) + assert(SecretVolumeUtils.containerHasVolume( + executor.getSpec.getContainers.get(0), "secret1-volume", "/var/secret1")) assert(executor.getSpec.getInitContainers.size() === 1) - assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0).getName - === "secret1-volume") - assert(executor.getSpec.getInitContainers.get(0).getVolumeMounts.get(0) - .getMountPath === "/var/secret1") + assert(SecretVolumeUtils.containerHasVolume( + executor.getSpec.getInitContainers.get(0), "secret1-volume", "/var/secret1")) checkOwnerReferences(executor, driverPodUid) } From ea9da6152af9223787cffd83d489741b4cc5aa34 Mon Sep 17 00:00:00 2001 From: Marcelo Vanzin Date: Thu, 4 Jan 2018 16:34:56 -0800 Subject: [PATCH 31/55] [SPARK-22960][K8S] Make build-push-docker-images.sh more dev-friendly. - Make it possible to build images from a git clone. - Make it easy to use minikube to test things. Also fixed what seemed like a bug: the base image wasn't getting the tag provided in the command line. Adding the tag allows users to use multiple Spark builds in the same kubernetes cluster. Tested by deploying images on minikube and running spark-submit from a dev environment; also by building the images with different tags and verifying "docker images" in minikube. Author: Marcelo Vanzin Closes #20154 from vanzin/SPARK-22960. (cherry picked from commit 0428368c2c5e135f99f62be20877bbbda43be310) Signed-off-by: Marcelo Vanzin --- docs/running-on-kubernetes.md | 9 +- .../src/main/dockerfiles/driver/Dockerfile | 3 +- .../src/main/dockerfiles/executor/Dockerfile | 3 +- .../dockerfiles/init-container/Dockerfile | 3 +- .../main/dockerfiles/spark-base/Dockerfile | 7 +- sbin/build-push-docker-images.sh | 120 +++++++++++++++--- 6 files changed, 117 insertions(+), 28 deletions(-) diff --git a/docs/running-on-kubernetes.md b/docs/running-on-kubernetes.md index e491329136a3c..2d69f636472ae 100644 --- a/docs/running-on-kubernetes.md +++ b/docs/running-on-kubernetes.md @@ -16,6 +16,9 @@ Kubernetes scheduler that has been added to Spark. you may setup a test cluster on your local machine using [minikube](https://kubernetes.io/docs/getting-started-guides/minikube/). * We recommend using the latest release of minikube with the DNS addon enabled. + * Be aware that the default minikube configuration is not enough for running Spark applications. + We recommend 3 CPUs and 4g of memory to be able to start a simple Spark application with a single + executor. * You must have appropriate permissions to list, create, edit and delete [pods](https://kubernetes.io/docs/user-guide/pods/) in your cluster. You can verify that you can list these resources by running `kubectl auth can-i pods`. @@ -197,7 +200,7 @@ kubectl port-forward 4040:4040 Then, the Spark driver UI can be accessed on `http://localhost:4040`. -### Debugging +### Debugging There may be several kinds of failures. If the Kubernetes API server rejects the request made from spark-submit, or the connection is refused for a different reason, the submission logic should indicate the error encountered. However, if there @@ -215,8 +218,8 @@ If the pod has encountered a runtime error, the status can be probed further usi kubectl logs ``` -Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark -application, includling all executors, associated service, etc. The driver pod can be thought of as the Kubernetes representation of +Status and logs of failed executor pods can be checked in similar ways. Finally, deleting the driver pod will clean up the entire spark +application, including all executors, associated service, etc. The driver pod can be thought of as the Kubernetes representation of the Spark application. ## Kubernetes Features diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile index 45fbcd9cd0deb..ff5289e10c21e 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile index 0f806cf7e148e..3eabb42d4d852 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index 047056ab2633b..e0a249e0ac71f 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -15,7 +15,8 @@ # limitations under the License. # -FROM spark-base +ARG base_image +FROM ${base_image} # If this docker file is being used in the context of building your images from a Spark distribution, the docker build # command should be invoked from the top level directory of the Spark distribution. E.g.: diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile index 222e777db3a82..da1d6b9e161cc 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark-base/Dockerfile @@ -17,6 +17,9 @@ FROM openjdk:8-alpine +ARG spark_jars +ARG img_path + # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. # If this docker file is being used in the context of building your images from a Spark @@ -34,11 +37,11 @@ RUN set -ex && \ ln -sv /bin/bash /bin/sh && \ chgrp root /etc/passwd && chmod ug+rw /etc/passwd -COPY jars /opt/spark/jars +COPY ${spark_jars} /opt/spark/jars COPY bin /opt/spark/bin COPY sbin /opt/spark/sbin COPY conf /opt/spark/conf -COPY kubernetes/dockerfiles/spark-base/entrypoint.sh /opt/ +COPY ${img_path}/spark-base/entrypoint.sh /opt/ ENV SPARK_HOME /opt/spark diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh index b3137598692d8..bb8806dd33f37 100755 --- a/sbin/build-push-docker-images.sh +++ b/sbin/build-push-docker-images.sh @@ -19,29 +19,94 @@ # This script builds and pushes docker images when run from a release of Spark # with Kubernetes support. -declare -A path=( [spark-driver]=kubernetes/dockerfiles/driver/Dockerfile \ - [spark-executor]=kubernetes/dockerfiles/executor/Dockerfile \ - [spark-init]=kubernetes/dockerfiles/init-container/Dockerfile ) +function error { + echo "$@" 1>&2 + exit 1 +} + +# Detect whether this is a git clone or a Spark distribution and adjust paths +# accordingly. +if [ -z "${SPARK_HOME}" ]; then + SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)" +fi +. "${SPARK_HOME}/bin/load-spark-env.sh" + +if [ -f "$SPARK_HOME/RELEASE" ]; then + IMG_PATH="kubernetes/dockerfiles" + SPARK_JARS="jars" +else + IMG_PATH="resource-managers/kubernetes/docker/src/main/dockerfiles" + SPARK_JARS="assembly/target/scala-$SPARK_SCALA_VERSION/jars" +fi + +if [ ! -d "$IMG_PATH" ]; then + error "Cannot find docker images. This script must be run from a runnable distribution of Apache Spark." +fi + +declare -A path=( [spark-driver]="$IMG_PATH/driver/Dockerfile" \ + [spark-executor]="$IMG_PATH/executor/Dockerfile" \ + [spark-init]="$IMG_PATH/init-container/Dockerfile" ) + +function image_ref { + local image="$1" + local add_repo="${2:-1}" + if [ $add_repo = 1 ] && [ -n "$REPO" ]; then + image="$REPO/$image" + fi + if [ -n "$TAG" ]; then + image="$image:$TAG" + fi + echo "$image" +} function build { - docker build -t spark-base -f kubernetes/dockerfiles/spark-base/Dockerfile . + local base_image="$(image_ref spark-base 0)" + docker build --build-arg "spark_jars=$SPARK_JARS" \ + --build-arg "img_path=$IMG_PATH" \ + -t "$base_image" \ + -f "$IMG_PATH/spark-base/Dockerfile" . for image in "${!path[@]}"; do - docker build -t ${REPO}/$image:${TAG} -f ${path[$image]} . + docker build --build-arg "base_image=$base_image" -t "$(image_ref $image)" -f ${path[$image]} . done } - function push { for image in "${!path[@]}"; do - docker push ${REPO}/$image:${TAG} + docker push "$(image_ref $image)" done } function usage { - echo "This script must be run from a runnable distribution of Apache Spark." - echo "Usage: ./sbin/build-push-docker-images.sh -r -t build" - echo " ./sbin/build-push-docker-images.sh -r -t push" - echo "for example: ./sbin/build-push-docker-images.sh -r docker.io/myrepo -t v2.3.0 push" + cat </dev/null; then + error "Cannot find minikube." + fi + eval $(minikube docker-env) + ;; esac done -if [ -z "$REPO" ] || [ -z "$TAG" ]; then +case "${@: -1}" in + build) + build + ;; + push) + if [ -z "$REPO" ]; then + usage + exit 1 + fi + push + ;; + *) usage -else - case "${@: -1}" in - build) build;; - push) push;; - *) usage;; - esac -fi + exit 1 + ;; +esac From 158f7e6a93b5acf4ce05c97b575124fd599cf927 Mon Sep 17 00:00:00 2001 From: Juliusz Sompolski Date: Fri, 5 Jan 2018 10:16:34 +0800 Subject: [PATCH 32/55] [SPARK-22957] ApproxQuantile breaks if the number of rows exceeds MaxInt ## What changes were proposed in this pull request? 32bit Int was used for row rank. That overflowed in a dataframe with more than 2B rows. ## How was this patch tested? Added test, but ignored, as it takes 4 minutes. Author: Juliusz Sompolski Closes #20152 from juliuszsompolski/SPARK-22957. (cherry picked from commit df7fc3ef3899cadd252d2837092bebe3442d6523) Signed-off-by: Wenchen Fan --- .../aggregate/ApproximatePercentile.scala | 12 ++++++------ .../spark/sql/catalyst/util/QuantileSummaries.scala | 8 ++++---- .../org/apache/spark/sql/DataFrameStatSuite.scala | 8 ++++++++ 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala index 149ac265e6ed5..a45854a3b5146 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/ApproximatePercentile.scala @@ -296,8 +296,8 @@ object ApproximatePercentile { Ints.BYTES + Doubles.BYTES + Longs.BYTES + // length of summary.sampled Ints.BYTES + - // summary.sampled, Array[Stat(value: Double, g: Int, delta: Int)] - summaries.sampled.length * (Doubles.BYTES + Ints.BYTES + Ints.BYTES) + // summary.sampled, Array[Stat(value: Double, g: Long, delta: Long)] + summaries.sampled.length * (Doubles.BYTES + Longs.BYTES + Longs.BYTES) } final def serialize(obj: PercentileDigest): Array[Byte] = { @@ -312,8 +312,8 @@ object ApproximatePercentile { while (i < summary.sampled.length) { val stat = summary.sampled(i) buffer.putDouble(stat.value) - buffer.putInt(stat.g) - buffer.putInt(stat.delta) + buffer.putLong(stat.g) + buffer.putLong(stat.delta) i += 1 } buffer.array() @@ -330,8 +330,8 @@ object ApproximatePercentile { var i = 0 while (i < sampledLength) { val value = buffer.getDouble() - val g = buffer.getInt() - val delta = buffer.getInt() + val g = buffer.getLong() + val delta = buffer.getLong() sampled(i) = Stats(value, g, delta) i += 1 } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala index eb7941cf9e6af..b013add9c9778 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/QuantileSummaries.scala @@ -105,7 +105,7 @@ class QuantileSummaries( if (newSamples.isEmpty || (sampleIdx == sampled.length && opsIdx == sorted.length - 1)) { 0 } else { - math.floor(2 * relativeError * currentCount).toInt + math.floor(2 * relativeError * currentCount).toLong } val tuple = Stats(currentSample, 1, delta) @@ -192,10 +192,10 @@ class QuantileSummaries( } // Target rank - val rank = math.ceil(quantile * count).toInt + val rank = math.ceil(quantile * count).toLong val targetError = relativeError * count // Minimum rank at current sample - var minRank = 0 + var minRank = 0L var i = 0 while (i < sampled.length - 1) { val curSample = sampled(i) @@ -235,7 +235,7 @@ object QuantileSummaries { * @param g the minimum rank jump from the previous value's minimum rank * @param delta the maximum span of the rank. */ - case class Stats(value: Double, g: Int, delta: Int) + case class Stats(value: Double, g: Long, delta: Long) private def compressImmut( currentSamples: IndexedSeq[Stats], diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 46b21c3b64a2e..5169d2b5fc6b2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -260,6 +260,14 @@ class DataFrameStatSuite extends QueryTest with SharedSQLContext { assert(res2(1).isEmpty) } + // SPARK-22957: check for 32bit overflow when computing rank. + // ignored - takes 4 minutes to run. + ignore("approx quantile 4: test for Int overflow") { + val res = spark.range(3000000000L).stat.approxQuantile("id", Array(0.8, 0.9), 0.05) + assert(res(0) > 2200000000.0) + assert(res(1) > 2200000000.0) + } + test("crosstab") { withSQLConf(SQLConf.SUPPORT_QUOTED_REGEX_COLUMN_NAME.key -> "false") { val rng = new Random() From 145820bda140d1385c4dd802fa79a871e6bf98be Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Fri, 5 Jan 2018 14:02:21 +0800 Subject: [PATCH 33/55] [SPARK-22825][SQL] Fix incorrect results of Casting Array to String ## What changes were proposed in this pull request? This pr fixed the issue when casting arrays into strings; ``` scala> val df = spark.range(10).select('id.cast("integer")).agg(collect_list('id).as('ids)) scala> df.write.saveAsTable("t") scala> sql("SELECT cast(ids as String) FROM t").show(false) +------------------------------------------------------------------+ |ids | +------------------------------------------------------------------+ |org.apache.spark.sql.catalyst.expressions.UnsafeArrayData8bc285df| +------------------------------------------------------------------+ ``` This pr modified the result into; ``` +------------------------------+ |ids | +------------------------------+ |[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]| +------------------------------+ ``` ## How was this patch tested? Added tests in `CastSuite` and `SQLQuerySuite`. Author: Takeshi Yamamuro Closes #20024 from maropu/SPARK-22825. (cherry picked from commit 52fc5c17d9d784b846149771b398e741621c0b5c) Signed-off-by: Wenchen Fan --- .../codegen/UTF8StringBuilder.java | 78 +++++++++++++++++++ .../spark/sql/catalyst/expressions/Cast.scala | 68 ++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 25 ++++++ .../org/apache/spark/sql/SQLQuerySuite.scala | 2 - 4 files changed, 171 insertions(+), 2 deletions(-) create mode 100644 sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java new file mode 100644 index 0000000000000..f0f66bae245fd --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/catalyst/expressions/codegen/UTF8StringBuilder.java @@ -0,0 +1,78 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.expressions.codegen; + +import org.apache.spark.unsafe.Platform; +import org.apache.spark.unsafe.array.ByteArrayMethods; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * A helper class to write {@link UTF8String}s to an internal buffer and build the concatenated + * {@link UTF8String} at the end. + */ +public class UTF8StringBuilder { + + private static final int ARRAY_MAX = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH; + + private byte[] buffer; + private int cursor = Platform.BYTE_ARRAY_OFFSET; + + public UTF8StringBuilder() { + // Since initial buffer size is 16 in `StringBuilder`, we set the same size here + this.buffer = new byte[16]; + } + + // Grows the buffer by at least `neededSize` + private void grow(int neededSize) { + if (neededSize > ARRAY_MAX - totalSize()) { + throw new UnsupportedOperationException( + "Cannot grow internal buffer by size " + neededSize + " because the size after growing " + + "exceeds size limitation " + ARRAY_MAX); + } + final int length = totalSize() + neededSize; + if (buffer.length < length) { + int newLength = length < ARRAY_MAX / 2 ? length * 2 : ARRAY_MAX; + final byte[] tmp = new byte[newLength]; + Platform.copyMemory( + buffer, + Platform.BYTE_ARRAY_OFFSET, + tmp, + Platform.BYTE_ARRAY_OFFSET, + totalSize()); + buffer = tmp; + } + } + + private int totalSize() { + return cursor - Platform.BYTE_ARRAY_OFFSET; + } + + public void append(UTF8String value) { + grow(value.numBytes()); + value.writeToMemory(buffer, cursor); + cursor += value.numBytes(); + } + + public void append(String value) { + append(UTF8String.fromString(value)); + } + + public UTF8String build() { + return UTF8String.fromBytes(buffer, 0, totalSize()); + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index 274d8813f16db..d4fc5e0f168a7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -206,6 +206,28 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case DateType => buildCast[Int](_, d => UTF8String.fromString(DateTimeUtils.dateToString(d))) case TimestampType => buildCast[Long](_, t => UTF8String.fromString(DateTimeUtils.timestampToString(t, timeZone))) + case ArrayType(et, _) => + buildCast[ArrayData](_, array => { + val builder = new UTF8StringBuilder + builder.append("[") + if (array.numElements > 0) { + val toUTF8String = castToString(et) + if (!array.isNullAt(0)) { + builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < array.numElements) { + builder.append(",") + if (!array.isNullAt(i)) { + builder.append(" ") + builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -597,6 +619,41 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """ } + private def writeArrayToStringBuilder( + et: DataType, + array: String, + buffer: String, + ctx: CodegenContext): String = { + val elementToStringCode = castToStringCode(et, ctx) + val funcName = ctx.freshName("elementToString") + val elementToStringFunc = ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${ctx.javaType(et)} element) { + | UTF8String elementStr = null; + | ${elementToStringCode("element", "elementStr", null /* resultIsNull won't be used */)} + | return elementStr; + |} + """.stripMargin) + + val loopIndex = ctx.freshName("loopIndex") + s""" + |$buffer.append("["); + |if ($array.numElements() > 0) { + | if (!$array.isNullAt(0)) { + | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, "0")})); + | } + | for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) { + | $buffer.append(","); + | if (!$array.isNullAt($loopIndex)) { + | $buffer.append(" "); + | $buffer.append($elementToStringFunc(${ctx.getValue(array, et, loopIndex)})); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -608,6 +665,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String val tz = ctx.addReferenceObj("timeZone", timeZone) (c, evPrim, evNull) => s"""$evPrim = UTF8String.fromString( org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampToString($c, $tz));""" + case ArrayType(et, _) => + (c, evPrim, evNull) => { + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx) + s""" + |$bufferClass $buffer = new $bufferClass(); + |$writeArrayElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index 1dd040e4696a1..e3ed7171defd8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -853,4 +853,29 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { cast("2", LongType).genCode(ctx) assert(ctx.inlinedMutableStates.length == 0) } + + test("SPARK-22825 Cast array to string") { + val ret1 = cast(Literal.create(Array(1, 2, 3, 4, 5)), StringType) + checkEvaluation(ret1, "[1, 2, 3, 4, 5]") + val ret2 = cast(Literal.create(Array("ab", "cde", "f")), StringType) + checkEvaluation(ret2, "[ab, cde, f]") + val ret3 = cast(Literal.create(Array("ab", null, "c")), StringType) + checkEvaluation(ret3, "[ab,, c]") + val ret4 = cast(Literal.create(Array("ab".getBytes, "cde".getBytes, "f".getBytes)), StringType) + checkEvaluation(ret4, "[ab, cde, f]") + val ret5 = cast( + Literal.create(Array("2014-12-03", "2014-12-04", "2014-12-06").map(Date.valueOf)), + StringType) + checkEvaluation(ret5, "[2014-12-03, 2014-12-04, 2014-12-06]") + val ret6 = cast( + Literal.create(Array("2014-12-03 13:01:00", "2014-12-04 15:05:00").map(Timestamp.valueOf)), + StringType) + checkEvaluation(ret6, "[2014-12-03 13:01:00, 2014-12-04 15:05:00]") + val ret7 = cast(Literal.create(Array(Array(1, 2, 3), Array(4, 5))), StringType) + checkEvaluation(ret7, "[[1, 2, 3], [4, 5]]") + val ret8 = cast( + Literal.create(Array(Array(Array("a"), Array("b", "c")), Array(Array("d")))), + StringType) + checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 5e077285ade55..96bf65fce9c4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -28,8 +28,6 @@ import org.apache.spark.scheduler.{SparkListener, SparkListenerJobStart} import org.apache.spark.sql.catalyst.util.StringUtils import org.apache.spark.sql.execution.aggregate import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} -import org.apache.spark.sql.execution.datasources.{HadoopFsRelation, LogicalRelation} -import org.apache.spark.sql.execution.datasources.orc.OrcFileFormat import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, CartesianProductExec, SortMergeJoinExec} import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf From 5b524cc0cd5a82e4fb0681363b6641e40b37075d Mon Sep 17 00:00:00 2001 From: Bago Amirbekian Date: Thu, 4 Jan 2018 22:45:15 -0800 Subject: [PATCH 34/55] [SPARK-22949][ML] Apply CrossValidator approach to Driver/Distributed memory tradeoff for TrainValidationSplit ## What changes were proposed in this pull request? Avoid holding all models in memory for `TrainValidationSplit`. ## How was this patch tested? Existing tests. Author: Bago Amirbekian Closes #20143 from MrBago/trainValidMemoryFix. (cherry picked from commit cf0aa65576acbe0209c67f04c029058fd73555c1) Signed-off-by: Joseph K. Bradley --- .../spark/ml/tuning/CrossValidator.scala | 4 +++- .../spark/ml/tuning/TrainValidationSplit.scala | 18 ++++-------------- 2 files changed, 7 insertions(+), 15 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 095b54c0fe83f..a0b507d2e718c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -160,8 +160,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String) } (executionContext) } - // Wait for metrics to be calculated before unpersisting validation dataset + // Wait for metrics to be calculated val foldMetrics = foldMetricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) + + // Unpersist training & validation set once all metrics have been produced trainingDataset.unpersist() validationDataset.unpersist() foldMetrics diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index c73bd18475475..8826ef3271bc1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -143,24 +143,13 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Fit models in a Future for training in parallel logDebug(s"Train split with multiple sets of parameters.") - val modelFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => - Future[Model[_]] { + val metricFutures = epm.zipWithIndex.map { case (paramMap, paramIndex) => + Future[Double] { val model = est.fit(trainingDataset, paramMap).asInstanceOf[Model[_]] if (collectSubModelsParam) { subModels.get(paramIndex) = model } - model - } (executionContext) - } - - // Unpersist training data only when all models have trained - Future.sequence[Model[_], Iterable](modelFutures)(implicitly, executionContext) - .onComplete { _ => trainingDataset.unpersist() } (executionContext) - - // Evaluate models in a Future that will calulate a metric and allow model to be cleaned up - val metricFutures = modelFutures.zip(epm).map { case (modelFuture, paramMap) => - modelFuture.map { model => // TODO: duplicate evaluator to take extra params from input val metric = eval.evaluate(model.transform(validationDataset, paramMap)) logDebug(s"Got metric $metric for model trained with $paramMap.") @@ -171,7 +160,8 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St // Wait for all metrics to be calculated val metrics = metricFutures.map(ThreadUtils.awaitResult(_, Duration.Inf)) - // Unpersist validation set once all metrics have been produced + // Unpersist training & validation set once all metrics have been produced + trainingDataset.unpersist() validationDataset.unpersist() logInfo(s"Train validation split metrics: ${metrics.toSeq}") From f9dcdbcefb545ced3f5b457e1e88c88a8e180f9f Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Thu, 4 Jan 2018 23:23:41 -0800 Subject: [PATCH 35/55] [SPARK-22757][K8S] Enable spark.jars and spark.files in KUBERNETES mode ## What changes were proposed in this pull request? We missed enabling `spark.files` and `spark.jars` in https://github.com/apache/spark/pull/19954. The result is that remote dependencies specified through `spark.files` or `spark.jars` are not included in the list of remote dependencies to be downloaded by the init-container. This PR fixes it. ## How was this patch tested? Manual tests. vanzin This replaces https://github.com/apache/spark/pull/20157. foxish Author: Yinan Li Closes #20160 from liyinan926/SPARK-22757. (cherry picked from commit 6cff7d19f6a905fe425bd6892fe7ca014c0e696b) Signed-off-by: Felix Cheung --- .../src/main/scala/org/apache/spark/deploy/SparkSubmit.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala index cbe1f2c3e08a1..1e381965c52ba 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala @@ -584,10 +584,11 @@ object SparkSubmit extends CommandLineUtils with Logging { confKey = "spark.executor.memory"), OptionAssigner(args.totalExecutorCores, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.cores.max"), - OptionAssigner(args.files, LOCAL | STANDALONE | MESOS, ALL_DEPLOY_MODES, + OptionAssigner(args.files, LOCAL | STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, confKey = "spark.files"), OptionAssigner(args.jars, LOCAL, CLIENT, confKey = "spark.jars"), - OptionAssigner(args.jars, STANDALONE | MESOS, ALL_DEPLOY_MODES, confKey = "spark.jars"), + OptionAssigner(args.jars, STANDALONE | MESOS | KUBERNETES, ALL_DEPLOY_MODES, + confKey = "spark.jars"), OptionAssigner(args.driverMemory, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, confKey = "spark.driver.memory"), OptionAssigner(args.driverCores, STANDALONE | MESOS | YARN | KUBERNETES, CLUSTER, From fd4e30476894b7c37cc2ae6243a941f0bc90388d Mon Sep 17 00:00:00 2001 From: Adrian Ionescu Date: Fri, 5 Jan 2018 21:32:39 +0800 Subject: [PATCH 36/55] [SPARK-22961][REGRESSION] Constant columns should generate QueryPlanConstraints ## What changes were proposed in this pull request? #19201 introduced the following regression: given something like `df.withColumn("c", lit(2))`, we're no longer picking up `c === 2` as a constraint and infer filters from it when joins are involved, which may lead to noticeable performance degradation. This patch re-enables this optimization by picking up Aliases of Literals in Projection lists as constraints and making sure they're not treated as aliased columns. ## How was this patch tested? Unit test was added. Author: Adrian Ionescu Closes #20155 from adrian-ionescu/constant_constraints. (cherry picked from commit 51c33bd0d402af9e0284c6cbc0111f926446bfba) Signed-off-by: gatorsmile --- .../sql/catalyst/plans/logical/LogicalPlan.scala | 2 ++ .../plans/logical/QueryPlanConstraints.scala | 2 +- .../InferFiltersFromConstraintsSuite.scala | 13 +++++++++++++ 3 files changed, 16 insertions(+), 1 deletion(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala index a38458add7b5e..ff2a0ec588567 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala @@ -247,6 +247,8 @@ abstract class UnaryNode extends LogicalPlan { protected def getAliasedConstraints(projectList: Seq[NamedExpression]): Set[Expression] = { var allConstraints = child.constraints.asInstanceOf[Set[Expression]] projectList.foreach { + case a @ Alias(l: Literal, _) => + allConstraints += EqualTo(a.toAttribute, l) case a @ Alias(e, _) => // For every alias in `projectList`, replace the reference in constraints by its attribute. allConstraints ++= allConstraints.map(_ transform { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala index b0f611fd38dea..9c0a30a47f839 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/QueryPlanConstraints.scala @@ -98,7 +98,7 @@ trait QueryPlanConstraints { self: LogicalPlan => // we may avoid producing recursive constraints. private lazy val aliasMap: AttributeMap[Expression] = AttributeMap( expressions.collect { - case a: Alias => (a.toAttribute, a.child) + case a: Alias if !a.child.isInstanceOf[Literal] => (a.toAttribute, a.child) } ++ children.flatMap(_.asInstanceOf[QueryPlanConstraints].aliasMap)) // Note: the explicit cast is necessary, since Scala compiler fails to infer the type. diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala index 5580f8604ec72..a0708bf7eee9a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/InferFiltersFromConstraintsSuite.scala @@ -236,4 +236,17 @@ class InferFiltersFromConstraintsSuite extends PlanTest { comparePlans(optimized, originalQuery) } } + + test("constraints should be inferred from aliased literals") { + val originalLeft = testRelation.subquery('left).as("left") + val optimizedLeft = testRelation.subquery('left).where(IsNotNull('a) && 'a === 2).as("left") + + val right = Project(Seq(Literal(2).as("two")), testRelation.subquery('right)).as("right") + val condition = Some("left.a".attr === "right.two".attr) + + val original = originalLeft.join(right, Inner, condition) + val correct = optimizedLeft.join(right, Inner, condition) + + comparePlans(Optimize.execute(original.analyze), correct.analyze) + } } From 0a30e93507ba784729a498943e7eeda1d6f19fbf Mon Sep 17 00:00:00 2001 From: Bruce Robbins Date: Fri, 5 Jan 2018 09:58:28 -0800 Subject: [PATCH 37/55] [SPARK-22940][SQL] HiveExternalCatalogVersionsSuite should succeed on platforms that don't have wget ## What changes were proposed in this pull request? Modified HiveExternalCatalogVersionsSuite.scala to use Utils.doFetchFile to download different versions of Spark binaries rather than launching wget as an external process. On platforms that don't have wget installed, this suite fails with an error. cloud-fan : would you like to check this change? ## How was this patch tested? 1) test-only of HiveExternalCatalogVersionsSuite on several platforms. Tested bad mirror, read timeout, and redirects. 2) ./dev/run-tests Author: Bruce Robbins Closes #20147 from bersprockets/SPARK-22940-alt. (cherry picked from commit c0b7424ecacb56d3e7a18acc11ba3d5e7be57c43) Signed-off-by: Marcelo Vanzin --- .../HiveExternalCatalogVersionsSuite.scala | 48 ++++++++++++++++--- 1 file changed, 42 insertions(+), 6 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala index a3d5b941a6761..ae4aeb7b4ce4a 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveExternalCatalogVersionsSuite.scala @@ -18,11 +18,14 @@ package org.apache.spark.sql.hive import java.io.File -import java.nio.file.Files +import java.nio.charset.StandardCharsets +import java.nio.file.{Files, Paths} import scala.sys.process._ -import org.apache.spark.TestUtils +import org.apache.hadoop.conf.Configuration + +import org.apache.spark.{SecurityManager, SparkConf, TestUtils} import org.apache.spark.sql.{QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.catalog.CatalogTableType @@ -55,14 +58,19 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { private def tryDownloadSpark(version: String, path: String): Unit = { // Try mirrors a few times until one succeeds for (i <- 0 until 3) { + // we don't retry on a failure to get mirror url. If we can't get a mirror url, + // the test fails (getStringFromUrl will throw an exception) val preferredMirror = - Seq("wget", "https://www.apache.org/dyn/closer.lua?preferred=true", "-q", "-O", "-").!!.trim - val url = s"$preferredMirror/spark/spark-$version/spark-$version-bin-hadoop2.7.tgz" + getStringFromUrl("https://www.apache.org/dyn/closer.lua?preferred=true") + val filename = s"spark-$version-bin-hadoop2.7.tgz" + val url = s"$preferredMirror/spark/spark-$version/$filename" logInfo(s"Downloading Spark $version from $url") - if (Seq("wget", url, "-q", "-P", path).! == 0) { + try { + getFileFromUrl(url, path, filename) return + } catch { + case ex: Exception => logWarning(s"Failed to download Spark $version from $url", ex) } - logWarning(s"Failed to download Spark $version from $url") } fail(s"Unable to download Spark $version") } @@ -85,6 +93,34 @@ class HiveExternalCatalogVersionsSuite extends SparkSubmitTestUtils { new File(tmpDataDir, name).getCanonicalPath } + private def getFileFromUrl(urlString: String, targetDir: String, filename: String): Unit = { + val conf = new SparkConf + // if the caller passes the name of an existing file, we want doFetchFile to write over it with + // the contents from the specified url. + conf.set("spark.files.overwrite", "true") + val securityManager = new SecurityManager(conf) + val hadoopConf = new Configuration + + val outDir = new File(targetDir) + if (!outDir.exists()) { + outDir.mkdirs() + } + + // propagate exceptions up to the caller of getFileFromUrl + Utils.doFetchFile(urlString, outDir, filename, conf, securityManager, hadoopConf) + } + + private def getStringFromUrl(urlString: String): String = { + val contentFile = File.createTempFile("string-", ".txt") + contentFile.deleteOnExit() + + // exceptions will propagate to the caller of getStringFromUrl + getFileFromUrl(urlString, contentFile.getParent, contentFile.getName) + + val contentPath = Paths.get(contentFile.toURI) + new String(Files.readAllBytes(contentPath), StandardCharsets.UTF_8) + } + override def beforeAll(): Unit = { super.beforeAll() From d1f422c1c12c8095e8522d1051a6e0e406748a3a Mon Sep 17 00:00:00 2001 From: "Joseph K. Bradley" Date: Fri, 5 Jan 2018 11:51:25 -0800 Subject: [PATCH 38/55] [SPARK-13030][ML] Follow-up cleanups for OneHotEncoderEstimator ## What changes were proposed in this pull request? Follow-up cleanups for the OneHotEncoderEstimator PR. See some discussion in the original PR: https://github.com/apache/spark/pull/19527 or read below for what this PR includes: * configedCategorySize: I reverted this to return an Array. I realized the original setup (which I had recommended in the original PR) caused the whole model to be serialized in the UDF. * encoder: I reorganized the logic to show what I meant in the comment in the previous PR. I think it's simpler but am open to suggestions. I also made some small style cleanups based on IntelliJ warnings. ## How was this patch tested? Existing unit tests Author: Joseph K. Bradley Closes #20132 from jkbradley/viirya-SPARK-13030. (cherry picked from commit 930b90a84871e2504b57ed50efa7b8bb52d3ba44) Signed-off-by: Joseph K. Bradley --- .../ml/feature/OneHotEncoderEstimator.scala | 92 ++++++++++--------- 1 file changed, 49 insertions(+), 43 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala index 074622d41e28d..bd1e3426c8780 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/OneHotEncoderEstimator.scala @@ -30,24 +30,27 @@ import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.expressions.UserDefinedFunction import org.apache.spark.sql.functions.{col, lit, udf} -import org.apache.spark.sql.types.{DoubleType, NumericType, StructField, StructType} +import org.apache.spark.sql.types.{DoubleType, StructField, StructType} /** Private trait for params and common methods for OneHotEncoderEstimator and OneHotEncoderModel */ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid with HasInputCols with HasOutputCols { /** - * Param for how to handle invalid data. + * Param for how to handle invalid data during transform(). * Options are 'keep' (invalid data presented as an extra categorical feature) or * 'error' (throw an error). + * Note that this Param is only used during transform; during fitting, invalid data + * will result in an error. * Default: "error" * @group param */ @Since("2.3.0") override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", - "How to handle invalid data " + + "How to handle invalid data during transform(). " + "Options are 'keep' (invalid data presented as an extra categorical feature) " + - "or error (throw an error).", + "or error (throw an error). Note that this Param is only used during transform; " + + "during fitting, invalid data will result in an error.", ParamValidators.inArray(OneHotEncoderEstimator.supportedHandleInvalids)) setDefault(handleInvalid, OneHotEncoderEstimator.ERROR_INVALID) @@ -66,10 +69,11 @@ private[ml] trait OneHotEncoderBase extends Params with HasHandleInvalid def getDropLast: Boolean = $(dropLast) protected def validateAndTransformSchema( - schema: StructType, dropLast: Boolean, keepInvalid: Boolean): StructType = { + schema: StructType, + dropLast: Boolean, + keepInvalid: Boolean): StructType = { val inputColNames = $(inputCols) val outputColNames = $(outputCols) - val existingFields = schema.fields require(inputColNames.length == outputColNames.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -197,6 +201,10 @@ object OneHotEncoderEstimator extends DefaultParamsReadable[OneHotEncoderEstimat override def load(path: String): OneHotEncoderEstimator = super.load(path) } +/** + * @param categorySizes Original number of categories for each feature being encoded. + * The array contains one value for each input column, in order. + */ @Since("2.3.0") class OneHotEncoderModel private[ml] ( @Since("2.3.0") override val uid: String, @@ -205,60 +213,58 @@ class OneHotEncoderModel private[ml] ( import OneHotEncoderModel._ - // Returns the category size for a given index with `dropLast` and `handleInvalid` + // Returns the category size for each index with `dropLast` and `handleInvalid` // taken into account. - private def configedCategorySize(orgCategorySize: Int, idx: Int): Int = { + private def getConfigedCategorySizes: Array[Int] = { val dropLast = getDropLast val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID if (!dropLast && keepInvalid) { // When `handleInvalid` is "keep", an extra category is added as last category // for invalid data. - orgCategorySize + 1 + categorySizes.map(_ + 1) } else if (dropLast && !keepInvalid) { // When `dropLast` is true, the last category is removed. - orgCategorySize - 1 + categorySizes.map(_ - 1) } else { // When `dropLast` is true and `handleInvalid` is "keep", the extra category for invalid // data is removed. Thus, it is the same as the plain number of categories. - orgCategorySize + categorySizes } } private def encoder: UserDefinedFunction = { - val oneValue = Array(1.0) - val emptyValues = Array.empty[Double] - val emptyIndices = Array.empty[Int] - val dropLast = getDropLast - val handleInvalid = getHandleInvalid - val keepInvalid = handleInvalid == OneHotEncoderEstimator.KEEP_INVALID + val keepInvalid = getHandleInvalid == OneHotEncoderEstimator.KEEP_INVALID + val configedSizes = getConfigedCategorySizes + val localCategorySizes = categorySizes // The udf performed on input data. The first parameter is the input value. The second - // parameter is the index of input. - udf { (label: Double, idx: Int) => - val plainNumCategories = categorySizes(idx) - val size = configedCategorySize(plainNumCategories, idx) - - if (label < 0) { - throw new SparkException(s"Negative value: $label. Input can't be negative.") - } else if (label == size && dropLast && !keepInvalid) { - // When `dropLast` is true and `handleInvalid` is not "keep", - // the last category is removed. - Vectors.sparse(size, emptyIndices, emptyValues) - } else if (label >= plainNumCategories && keepInvalid) { - // When `handleInvalid` is "keep", encodes invalid data to last category (and removed - // if `dropLast` is true) - if (dropLast) { - Vectors.sparse(size, emptyIndices, emptyValues) + // parameter is the index in inputCols of the column being encoded. + udf { (label: Double, colIdx: Int) => + val origCategorySize = localCategorySizes(colIdx) + // idx: index in vector of the single 1-valued element + val idx = if (label >= 0 && label < origCategorySize) { + label + } else { + if (keepInvalid) { + origCategorySize } else { - Vectors.sparse(size, Array(size - 1), oneValue) + if (label < 0) { + throw new SparkException(s"Negative value: $label. Input can't be negative. " + + s"To handle invalid values, set Param handleInvalid to " + + s"${OneHotEncoderEstimator.KEEP_INVALID}") + } else { + throw new SparkException(s"Unseen value: $label. To handle unseen values, " + + s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") + } } - } else if (label < plainNumCategories) { - Vectors.sparse(size, Array(label.toInt), oneValue) + } + + val size = configedSizes(colIdx) + if (idx < size) { + Vectors.sparse(size, Array(idx.toInt), Array(1.0)) } else { - assert(handleInvalid == OneHotEncoderEstimator.ERROR_INVALID) - throw new SparkException(s"Unseen value: $label. To handle unseen values, " + - s"set Param handleInvalid to ${OneHotEncoderEstimator.KEEP_INVALID}.") + Vectors.sparse(size, Array.empty[Int], Array.empty[Double]) } } } @@ -282,7 +288,6 @@ class OneHotEncoderModel private[ml] ( @Since("2.3.0") override def transformSchema(schema: StructType): StructType = { val inputColNames = $(inputCols) - val outputColNames = $(outputCols) require(inputColNames.length == categorySizes.length, s"The number of input columns ${inputColNames.length} must be the same as the number of " + @@ -300,6 +305,7 @@ class OneHotEncoderModel private[ml] ( * account. Mismatched numbers will cause exception. */ private def verifyNumOfValues(schema: StructType): StructType = { + val configedSizes = getConfigedCategorySizes $(outputCols).zipWithIndex.foreach { case (outputColName, idx) => val inputColName = $(inputCols)(idx) val attrGroup = AttributeGroup.fromStructField(schema(outputColName)) @@ -308,9 +314,9 @@ class OneHotEncoderModel private[ml] ( // comparing with expected category number with `handleInvalid` and // `dropLast` taken into account. if (attrGroup.attributes.nonEmpty) { - val numCategories = configedCategorySize(categorySizes(idx), idx) + val numCategories = configedSizes(idx) require(attrGroup.size == numCategories, "OneHotEncoderModel expected " + - s"$numCategories categorical values for input column ${inputColName}, " + + s"$numCategories categorical values for input column $inputColName, " + s"but the input column had metadata specifying ${attrGroup.size} values.") } } @@ -322,7 +328,7 @@ class OneHotEncoderModel private[ml] ( val transformedSchema = transformSchema(dataset.schema, logging = true) val keepInvalid = $(handleInvalid) == OneHotEncoderEstimator.KEEP_INVALID - val encodedColumns = (0 until $(inputCols).length).map { idx => + val encodedColumns = $(inputCols).indices.map { idx => val inputColName = $(inputCols)(idx) val outputColName = $(outputCols)(idx) From 55afac4e7b4f655aa05c5bcaf7851bb1e7699dba Mon Sep 17 00:00:00 2001 From: Gera Shegalov Date: Fri, 5 Jan 2018 17:25:28 -0800 Subject: [PATCH 39/55] [SPARK-22914][DEPLOY] Register history.ui.port ## What changes were proposed in this pull request? Register spark.history.ui.port as a known spark conf to be used in substitution expressions even if it's not set explicitly. ## How was this patch tested? Added unit test to demonstrate the issue Author: Gera Shegalov Author: Gera Shegalov Closes #20098 from gerashegalov/gera/register-SHS-port-conf. (cherry picked from commit ea956833017fcbd8ed2288368bfa2e417a2251c5) Signed-off-by: Marcelo Vanzin --- .../spark/deploy/history/HistoryServer.scala | 3 +- .../apache/spark/deploy/history/config.scala | 5 +++ .../spark/deploy/yarn/ApplicationMaster.scala | 17 +++++--- .../deploy/yarn/ApplicationMasterSuite.scala | 43 +++++++++++++++++++ 4 files changed, 62 insertions(+), 6 deletions(-) create mode 100644 resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala index 75484f5c9f30f..0ec4afad0308c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala @@ -28,6 +28,7 @@ import org.eclipse.jetty.servlet.{ServletContextHandler, ServletHolder} import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.deploy.history.config.HISTORY_SERVER_UI_PORT import org.apache.spark.internal.Logging import org.apache.spark.internal.config._ import org.apache.spark.status.api.v1.{ApiRootResource, ApplicationInfo, UIRoot} @@ -276,7 +277,7 @@ object HistoryServer extends Logging { .newInstance(conf) .asInstanceOf[ApplicationHistoryProvider] - val port = conf.getInt("spark.history.ui.port", 18080) + val port = conf.get(HISTORY_SERVER_UI_PORT) val server = new HistoryServer(conf, provider, securityManager, port) server.bind() diff --git a/core/src/main/scala/org/apache/spark/deploy/history/config.scala b/core/src/main/scala/org/apache/spark/deploy/history/config.scala index 22b6d49d8e2a4..efdbf672bb52f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/history/config.scala +++ b/core/src/main/scala/org/apache/spark/deploy/history/config.scala @@ -44,4 +44,9 @@ private[spark] object config { .bytesConf(ByteUnit.BYTE) .createWithDefaultString("10g") + val HISTORY_SERVER_UI_PORT = ConfigBuilder("spark.history.ui.port") + .doc("Web UI port to bind Spark History Server") + .intConf + .createWithDefault(18080) + } diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index b2576b0d72633..4d5e3bb043671 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -427,11 +427,8 @@ private[spark] class ApplicationMaster(args: ApplicationMasterArguments) extends uiAddress: Option[String]) = { val appId = client.getAttemptId().getApplicationId().toString() val attemptId = client.getAttemptId().getAttemptId().toString() - val historyAddress = - _sparkConf.get(HISTORY_SERVER_ADDRESS) - .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } - .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } - .getOrElse("") + val historyAddress = ApplicationMaster + .getHistoryServerAddress(_sparkConf, yarnConf, appId, attemptId) val driverUrl = RpcEndpointAddress( _sparkConf.get("spark.driver.host"), @@ -834,6 +831,16 @@ object ApplicationMaster extends Logging { master.getAttemptId } + private[spark] def getHistoryServerAddress( + sparkConf: SparkConf, + yarnConf: YarnConfiguration, + appId: String, + attemptId: String): String = { + sparkConf.get(HISTORY_SERVER_ADDRESS) + .map { text => SparkHadoopUtil.get.substituteHadoopVariables(text, yarnConf) } + .map { address => s"${address}${HistoryServer.UI_PATH_PREFIX}/${appId}/${attemptId}" } + .getOrElse("") + } } /** diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala new file mode 100644 index 0000000000000..695a82f3583e6 --- /dev/null +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/ApplicationMasterSuite.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import org.apache.hadoop.yarn.conf.YarnConfiguration + +import org.apache.spark.{SparkConf, SparkFunSuite} + +class ApplicationMasterSuite extends SparkFunSuite { + + test("history url with hadoop and spark substitutions") { + val host = "rm.host.com" + val port = 18080 + val sparkConf = new SparkConf() + + sparkConf.set("spark.yarn.historyServer.address", + "http://${hadoopconf-yarn.resourcemanager.hostname}:${spark.history.ui.port}") + val yarnConf = new YarnConfiguration() + yarnConf.set("yarn.resourcemanager.hostname", host) + val appId = "application_123_1" + val attemptId = appId + "_1" + + val shsAddr = ApplicationMaster + .getHistoryServerAddress(sparkConf, yarnConf, appId, attemptId) + + assert(shsAddr === s"http://${host}:${port}/history/${appId}/${attemptId}") + } +} From bf853018cabcd3b3abf84bfe534d2981020b4a71 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sat, 6 Jan 2018 09:26:03 +0800 Subject: [PATCH 40/55] [SPARK-22937][SQL] SQL elt output binary for binary inputs ## What changes were proposed in this pull request? This pr modified `elt` to output binary for binary inputs. `elt` in the current master always output data as a string. But, in some databases (e.g., MySQL), if all inputs are binary, `elt` also outputs binary (Also, this might be a small surprise). This pr is related to #19977. ## How was this patch tested? Added tests in `SQLQueryTestSuite` and `TypeCoercionSuite`. Author: Takeshi Yamamuro Closes #20135 from maropu/SPARK-22937. (cherry picked from commit e8af7e8aeca15a6107248f358d9514521ffdc6d3) Signed-off-by: gatorsmile --- docs/sql-programming-guide.md | 2 + .../sql/catalyst/analysis/TypeCoercion.scala | 29 +++++ .../expressions/stringExpressions.scala | 46 ++++--- .../apache/spark/sql/internal/SQLConf.scala | 8 ++ .../catalyst/analysis/TypeCoercionSuite.scala | 54 ++++++++ .../inputs/typeCoercion/native/elt.sql | 44 +++++++ .../results/typeCoercion/native/elt.sql.out | 115 ++++++++++++++++++ 7 files changed, 281 insertions(+), 17 deletions(-) create mode 100644 sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql create mode 100644 sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index dc3e384008d27..b50f9360b866c 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -1783,6 +1783,8 @@ options. - Since Spark 2.3, when all inputs are binary, `functions.concat()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.concatBinaryAsString` to `true`. + - Since Spark 2.3, when all inputs are binary, SQL `elt()` returns an output as binary. Otherwise, it returns as a string. Until Spark 2.3, it always returns as a string despite of input types. To keep the old behavior, set `spark.sql.function.eltOutputAsString` to `true`. + ## Upgrading From Spark SQL 2.1 to 2.2 - Spark 2.1.1 introduced a new configuration key: `spark.sql.hive.caseSensitiveInferenceMode`. It had a default setting of `NEVER_INFER`, which kept behavior identical to 2.1.0. However, Spark 2.2.0 changes this setting's default value to `INFER_AND_SAVE` to restore compatibility with reading Hive metastore tables whose underlying file schema have mixed-case column names. With the `INFER_AND_SAVE` configuration value, on first access Spark will perform schema inference on any Hive metastore table for which it has not already saved an inferred schema. Note that schema inference can be a very time consuming operation for tables with thousands of partitions. If compatibility with mixed-case column names is not a concern, you can safely set `spark.sql.hive.caseSensitiveInferenceMode` to `NEVER_INFER` to avoid the initial overhead of schema inference. Note that with the new default `INFER_AND_SAVE` setting, the results of the schema inference are saved as a metastore key for future use. Therefore, the initial schema inference occurs only at a table's first access. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala index e9436367c7e2e..e8669c4637d06 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercion.scala @@ -54,6 +54,7 @@ object TypeCoercion { BooleanEquality :: FunctionArgumentConversion :: ConcatCoercion(conf) :: + EltCoercion(conf) :: CaseWhenCoercion :: IfCoercion :: StackCoercion :: @@ -684,6 +685,34 @@ object TypeCoercion { } } + /** + * Coerces the types of [[Elt]] children to expected ones. + * + * If `spark.sql.function.eltOutputAsString` is false and all children types are binary, + * the expected types are binary. Otherwise, the expected ones are strings. + */ + case class EltCoercion(conf: SQLConf) extends TypeCoercionRule { + + override protected def coerceTypes(plan: LogicalPlan): LogicalPlan = plan transform { case p => + p transformExpressionsUp { + // Skip nodes if unresolved or not enough children + case c @ Elt(children) if !c.childrenResolved || children.size < 2 => c + case c @ Elt(children) => + val index = children.head + val newIndex = ImplicitTypeCasts.implicitCast(index, IntegerType).getOrElse(index) + val newInputs = if (conf.eltOutputAsString || + !children.tail.map(_.dataType).forall(_ == BinaryType)) { + children.tail.map { e => + ImplicitTypeCasts.implicitCast(e, StringType).getOrElse(e) + } + } else { + children.tail + } + c.copy(children = newIndex +: newInputs) + } + } + } + /** * Turns Add/Subtract of DateType/TimestampType/StringType and CalendarIntervalType * to TimeAdd/TimeSub diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala index 41dc762154a4c..e004bfc6af473 100755 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala @@ -271,33 +271,45 @@ case class ConcatWs(children: Seq[Expression]) } } +/** + * An expression that returns the `n`-th input in given inputs. + * If all inputs are binary, `elt` returns an output as binary. Otherwise, it returns as string. + * If any input is null, `elt` returns null. + */ // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(n, str1, str2, ...) - Returns the `n`-th string, e.g., returns `str2` when `n` is 2.", + usage = "_FUNC_(n, input1, input2, ...) - Returns the `n`-th input, e.g., returns `input2` when `n` is 2.", examples = """ Examples: > SELECT _FUNC_(1, 'scala', 'java'); scala """) // scalastyle:on line.size.limit -case class Elt(children: Seq[Expression]) - extends Expression with ImplicitCastInputTypes { +case class Elt(children: Seq[Expression]) extends Expression { private lazy val indexExpr = children.head - private lazy val stringExprs = children.tail.toArray + private lazy val inputExprs = children.tail.toArray /** This expression is always nullable because it returns null if index is out of range. */ override def nullable: Boolean = true - override def dataType: DataType = StringType - - override def inputTypes: Seq[DataType] = IntegerType +: Seq.fill(children.size - 1)(StringType) + override def dataType: DataType = inputExprs.map(_.dataType).headOption.getOrElse(StringType) override def checkInputDataTypes(): TypeCheckResult = { if (children.size < 2) { TypeCheckResult.TypeCheckFailure("elt function requires at least two arguments") } else { - super[ImplicitCastInputTypes].checkInputDataTypes() + val (indexType, inputTypes) = (indexExpr.dataType, inputExprs.map(_.dataType)) + if (indexType != IntegerType) { + return TypeCheckResult.TypeCheckFailure(s"first input to function $prettyName should " + + s"have IntegerType, but it's $indexType") + } + if (inputTypes.exists(tpe => !Seq(StringType, BinaryType).contains(tpe))) { + return TypeCheckResult.TypeCheckFailure( + s"input to function $prettyName should have StringType or BinaryType, but it's " + + inputTypes.map(_.simpleString).mkString("[", ", ", "]")) + } + TypeUtils.checkForSameTypeInputExpr(inputTypes, s"function $prettyName") } } @@ -307,27 +319,27 @@ case class Elt(children: Seq[Expression]) null } else { val index = indexObj.asInstanceOf[Int] - if (index <= 0 || index > stringExprs.length) { + if (index <= 0 || index > inputExprs.length) { null } else { - stringExprs(index - 1).eval(input) + inputExprs(index - 1).eval(input) } } } override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = { val index = indexExpr.genCode(ctx) - val strings = stringExprs.map(_.genCode(ctx)) + val inputs = inputExprs.map(_.genCode(ctx)) val indexVal = ctx.freshName("index") val indexMatched = ctx.freshName("eltIndexMatched") - val stringVal = ctx.addMutableState(ctx.javaType(dataType), "stringVal") + val inputVal = ctx.addMutableState(ctx.javaType(dataType), "inputVal") - val assignStringValue = strings.zipWithIndex.map { case (eval, index) => + val assignInputValue = inputs.zipWithIndex.map { case (eval, index) => s""" |if ($indexVal == ${index + 1}) { | ${eval.code} - | $stringVal = ${eval.isNull} ? null : ${eval.value}; + | $inputVal = ${eval.isNull} ? null : ${eval.value}; | $indexMatched = true; | continue; |} @@ -335,7 +347,7 @@ case class Elt(children: Seq[Expression]) } val codes = ctx.splitExpressionsWithCurrentInputs( - expressions = assignStringValue, + expressions = assignInputValue, funcName = "eltFunc", extraArguments = ("int", indexVal) :: Nil, returnType = ctx.JAVA_BOOLEAN, @@ -361,11 +373,11 @@ case class Elt(children: Seq[Expression]) |${index.code} |final int $indexVal = ${index.value}; |${ctx.JAVA_BOOLEAN} $indexMatched = false; - |$stringVal = null; + |$inputVal = null; |do { | $codes |} while (false); - |final UTF8String ${ev.value} = $stringVal; + |final ${ctx.javaType(dataType)} ${ev.value} = $inputVal; |final boolean ${ev.isNull} = ${ev.value} == null; """.stripMargin) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 5d6edf6b8abec..80b8965e084a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -1052,6 +1052,12 @@ object SQLConf { .booleanConf .createWithDefault(false) + val ELT_OUTPUT_AS_STRING = buildConf("spark.sql.function.eltOutputAsString") + .doc("When this option is set to false and all inputs are binary, `elt` returns " + + "an output as binary. Otherwise, it returns as a string. ") + .booleanConf + .createWithDefault(false) + val CONTINUOUS_STREAMING_EXECUTOR_QUEUE_SIZE = buildConf("spark.sql.streaming.continuous.executorQueueSize") .internal() @@ -1412,6 +1418,8 @@ class SQLConf extends Serializable with Logging { def concatBinaryAsString: Boolean = getConf(CONCAT_BINARY_AS_STRING) + def eltOutputAsString: Boolean = getConf(ELT_OUTPUT_AS_STRING) + def partitionOverwriteMode: PartitionOverwriteMode.Value = PartitionOverwriteMode.withName(getConf(PARTITION_OVERWRITE_MODE)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala index 3661530cd622b..52a7ebdafd7c7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/TypeCoercionSuite.scala @@ -923,6 +923,60 @@ class TypeCoercionSuite extends AnalysisTest { } } + test("type coercion for Elt") { + val rule = TypeCoercion.EltCoercion(conf) + + ruleTest(rule, + Elt(Seq(Literal(1), Literal("ab"), Literal("cde"))), + Elt(Seq(Literal(1), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(1.toShort), Literal("ab"), Literal("cde"))), + Elt(Seq(Cast(Literal(1.toShort), IntegerType), Literal("ab"), Literal("cde")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(null), Literal("abc"))), + Elt(Seq(Literal(2), Cast(Literal(null), StringType), Literal("abc")))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(1), Literal("234"))), + Elt(Seq(Literal(2), Cast(Literal(1), StringType), Literal("234")))) + ruleTest(rule, + Elt(Seq(Literal(3), Literal(1L), Literal(2.toByte), Literal(0.1))), + Elt(Seq(Literal(3), Cast(Literal(1L), StringType), Cast(Literal(2.toByte), StringType), + Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(true), Literal(0.1f), Literal(3.toShort))), + Elt(Seq(Literal(2), Cast(Literal(true), StringType), Cast(Literal(0.1f), StringType), + Cast(Literal(3.toShort), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(1L), Literal(0.1))), + Elt(Seq(Literal(1), Cast(Literal(1L), StringType), Cast(Literal(0.1), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(Decimal(10)))), + Elt(Seq(Literal(1), Cast(Literal(Decimal(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(1), Literal(java.math.BigDecimal.valueOf(10)))), + Elt(Seq(Literal(1), Cast(Literal(java.math.BigDecimal.valueOf(10)), StringType)))) + ruleTest(rule, + Elt(Seq(Literal(2), Literal(new java.sql.Date(0)), Literal(new Timestamp(0)))), + Elt(Seq(Literal(2), Cast(Literal(new java.sql.Date(0)), StringType), + Cast(Literal(new Timestamp(0)), StringType)))) + + withSQLConf("spark.sql.function.eltOutputAsString" -> "true") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Cast(Literal("123".getBytes), StringType), + Cast(Literal("456".getBytes), StringType)))) + } + + withSQLConf("spark.sql.function.eltOutputAsString" -> "false") { + ruleTest(rule, + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes))), + Elt(Seq(Literal(1), Literal("123".getBytes), Literal("456".getBytes)))) + } + } + test("BooleanEquality type cast") { val be = TypeCoercion.BooleanEquality // Use something more than a literal to avoid triggering the simplification rules. diff --git a/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql new file mode 100644 index 0000000000000..717616f91db05 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/typeCoercion/native/elt.sql @@ -0,0 +1,44 @@ +-- Mixed inputs (output type is string) +SELECT elt(2, col1, col2, col3, col4, col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +); + +SELECT elt(3, col1, col2, col3, col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +); + +-- turn on eltOutputAsString +set spark.sql.function.eltOutputAsString=true; + +SELECT elt(1, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); + +-- turn off eltOutputAsString +set spark.sql.function.eltOutputAsString=false; + +-- Elt binary inputs (output type is binary) +SELECT elt(2, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +); diff --git a/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out new file mode 100644 index 0000000000000..b62e1b6826045 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/typeCoercion/native/elt.sql.out @@ -0,0 +1,115 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +SELECT elt(2, col1, col2, col3, col4, col5) col +FROM ( + SELECT + 'prefix_' col1, + id col2, + string(id + 1) col3, + encode(string(id + 2), 'utf-8') col4, + CAST(id AS DOUBLE) col5 + FROM range(10) +) +-- !query 0 schema +struct +-- !query 0 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 1 +SELECT elt(3, col1, col2, col3, col4) col +FROM ( + SELECT + string(id) col1, + string(id + 1) col2, + encode(string(id + 2), 'utf-8') col3, + encode(string(id + 3), 'utf-8') col4 + FROM range(10) +) +-- !query 1 schema +struct +-- !query 1 output +10 +11 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 2 +set spark.sql.function.eltOutputAsString=true +-- !query 2 schema +struct +-- !query 2 output +spark.sql.function.eltOutputAsString true + + +-- !query 3 +SELECT elt(1, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 3 schema +struct +-- !query 3 output +0 +1 +2 +3 +4 +5 +6 +7 +8 +9 + + +-- !query 4 +set spark.sql.function.eltOutputAsString=false +-- !query 4 schema +struct +-- !query 4 output +spark.sql.function.eltOutputAsString false + + +-- !query 5 +SELECT elt(2, col1, col2) col +FROM ( + SELECT + encode(string(id), 'utf-8') col1, + encode(string(id + 1), 'utf-8') col2 + FROM range(10) +) +-- !query 5 schema +struct +-- !query 5 output +1 +10 +2 +3 +4 +5 +6 +7 +8 +9 From 3e3e9386ed95435a2d1817653d1402c102e380dc Mon Sep 17 00:00:00 2001 From: Yinan Li Date: Fri, 5 Jan 2018 17:29:27 -0800 Subject: [PATCH 41/55] [SPARK-22960][K8S] Revert use of ARG base_image in images ## What changes were proposed in this pull request? This PR reverts the `ARG base_image` before `FROM` in the images of driver, executor, and init-container, introduced in https://github.com/apache/spark/pull/20154. The reason is Docker versions before 17.06 do not support this use (`ARG` before `FROM`). ## How was this patch tested? Tested manually. vanzin foxish kimoonkim Author: Yinan Li Closes #20170 from liyinan926/master. (cherry picked from commit bf65cd3cda46d5480bfcd13110975c46ca631972) Signed-off-by: Marcelo Vanzin --- .../docker/src/main/dockerfiles/driver/Dockerfile | 3 +-- .../docker/src/main/dockerfiles/executor/Dockerfile | 3 +-- .../docker/src/main/dockerfiles/init-container/Dockerfile | 3 +-- sbin/build-push-docker-images.sh | 8 ++++---- 4 files changed, 7 insertions(+), 10 deletions(-) diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile index ff5289e10c21e..45fbcd9cd0deb 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/driver/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile index 3eabb42d4d852..0f806cf7e148e 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/executor/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # Before building the docker image, first build and make a Spark distribution following # the instructions in http://spark.apache.org/docs/latest/building-spark.html. diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile index e0a249e0ac71f..047056ab2633b 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/init-container/Dockerfile @@ -15,8 +15,7 @@ # limitations under the License. # -ARG base_image -FROM ${base_image} +FROM spark-base # If this docker file is being used in the context of building your images from a Spark distribution, the docker build # command should be invoked from the top level directory of the Spark distribution. E.g.: diff --git a/sbin/build-push-docker-images.sh b/sbin/build-push-docker-images.sh index bb8806dd33f37..b9532597419a5 100755 --- a/sbin/build-push-docker-images.sh +++ b/sbin/build-push-docker-images.sh @@ -60,13 +60,13 @@ function image_ref { } function build { - local base_image="$(image_ref spark-base 0)" - docker build --build-arg "spark_jars=$SPARK_JARS" \ + docker build \ + --build-arg "spark_jars=$SPARK_JARS" \ --build-arg "img_path=$IMG_PATH" \ - -t "$base_image" \ + -t spark-base \ -f "$IMG_PATH/spark-base/Dockerfile" . for image in "${!path[@]}"; do - docker build --build-arg "base_image=$base_image" -t "$(image_ref $image)" -f ${path[$image]} . + docker build -t "$(image_ref $image)" -f ${path[$image]} . done } From 7236914e5e7aeb4eb919530b6edbad70256cca52 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Sat, 6 Jan 2018 16:11:20 +0800 Subject: [PATCH 42/55] [SPARK-22930][PYTHON][SQL] Improve the description of Vectorized UDFs for non-deterministic cases ## What changes were proposed in this pull request? Add tests for using non deterministic UDFs in aggregate. Update pandas_udf docstring w.r.t to determinism. ## How was this patch tested? test_nondeterministic_udf_in_aggregate Author: Li Jin Closes #20142 from icexelloss/SPARK-22930-pandas-udf-deterministic. (cherry picked from commit f2dd8b923759e8771b0e5f59bfa7ae4ad7e6a339) Signed-off-by: gatorsmile --- python/pyspark/sql/functions.py | 12 +++++++- python/pyspark/sql/tests.py | 52 +++++++++++++++++++++++++++++++++ 2 files changed, 63 insertions(+), 1 deletion(-) diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py index a4ed562ad48b4..733e32bd825b0 100644 --- a/python/pyspark/sql/functions.py +++ b/python/pyspark/sql/functions.py @@ -2214,7 +2214,17 @@ def pandas_udf(f=None, returnType=None, functionType=None): .. seealso:: :meth:`pyspark.sql.GroupedData.apply` - .. note:: The user-defined function must be deterministic. + .. note:: The user-defined functions are considered deterministic by default. Due to + optimization, duplicate invocations may be eliminated or the function may even be invoked + more times than it is present in the query. If your function is not deterministic, call + `asNondeterministic` on the user defined function. E.g.: + + >>> @pandas_udf('double', PandasUDFType.SCALAR) # doctest: +SKIP + ... def random(v): + ... import numpy as np + ... import pandas as pd + ... return pd.Series(np.random.randn(len(v)) + >>> random = random.asNondeterministic() # doctest: +SKIP .. note:: The user-defined functions do not support conditional expressions or short curcuiting in boolean expressions and it ends up with being executed all internally. If the functions diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 6dc767f9ec46e..689736d8e6456 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -386,6 +386,7 @@ def test_udf3(self): self.assertEqual(row[0], 5) def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations from pyspark.sql.functions import udf import random udf_random_col = udf(lambda: int(100 * random.random()), IntegerType()).asNondeterministic() @@ -413,6 +414,18 @@ def test_nondeterministic_udf2(self): pydoc.render_doc(random_udf) pydoc.render_doc(random_udf1) + def test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import udf, sum + import random + udf_random_col = udf(lambda: int(100 * random.random()), 'int').asNondeterministic() + df = self.spark.range(10) + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.groupby('id').agg(sum(udf_random_col())).collect() + with self.assertRaisesRegexp(AnalysisException, "nondeterministic"): + df.agg(sum(udf_random_col())).collect() + def test_chained_udf(self): self.spark.catalog.registerFunction("double", lambda x: x + x, IntegerType()) [row] = self.spark.sql("SELECT double(1)").collect() @@ -3567,6 +3580,18 @@ def tearDownClass(cls): time.tzset() ReusedSQLTestCase.tearDownClass() + @property + def random_udf(self): + from pyspark.sql.functions import pandas_udf + + @pandas_udf('double') + def random_udf(v): + import pandas as pd + import numpy as np + return pd.Series(np.random.random(len(v))) + random_udf = random_udf.asNondeterministic() + return random_udf + def test_vectorized_udf_basic(self): from pyspark.sql.functions import pandas_udf, col df = self.spark.range(10).select( @@ -3950,6 +3975,33 @@ def test_vectorized_udf_timestamps_respect_session_timezone(self): finally: self.spark.conf.set("spark.sql.session.timeZone", orig_tz) + def test_nondeterministic_udf(self): + # Test that nondeterministic UDFs are evaluated only once in chained UDF evaluations + from pyspark.sql.functions import udf, pandas_udf, col + + @pandas_udf('double') + def plus_ten(v): + return v + 10 + random_udf = self.random_udf + + df = self.spark.range(10).withColumn('rand', random_udf(col('id'))) + result1 = df.withColumn('plus_ten(rand)', plus_ten(df['rand'])).toPandas() + + self.assertEqual(random_udf.deterministic, False) + self.assertTrue(result1['plus_ten(rand)'].equals(result1['rand'] + 10)) + + def test_nondeterministic_udf_in_aggregate(self): + from pyspark.sql.functions import pandas_udf, sum + + df = self.spark.range(10) + random_udf = self.random_udf + + with QuietTest(self.sc): + with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): + df.groupby(df.id).agg(sum(random_udf(df.id))).collect() + with self.assertRaisesRegexp(AnalysisException, 'nondeterministic'): + df.agg(sum(random_udf(df.id))).collect() + @unittest.skipIf(not _have_pandas or not _have_arrow, "Pandas or Arrow not installed") class GroupbyApplyTests(ReusedSQLTestCase): From e6449e8167776e3921c286d75e8cdd30ee33d77a Mon Sep 17 00:00:00 2001 From: zuotingbing Date: Sat, 6 Jan 2018 18:07:45 +0800 Subject: [PATCH 43/55] [SPARK-22793][SQL] Memory leak in Spark Thrift Server # What changes were proposed in this pull request? 1. Start HiveThriftServer2. 2. Connect to thriftserver through beeline. 3. Close the beeline. 4. repeat step2 and step 3 for many times. we found there are many directories never be dropped under the path `hive.exec.local.scratchdir` and `hive.exec.scratchdir`, as we know the scratchdir has been added to deleteOnExit when it be created. So it means that the cache size of FileSystem `deleteOnExit` will keep increasing until JVM terminated. In addition, we use `jmap -histo:live [PID]` to printout the size of objects in HiveThriftServer2 Process, we can find the object `org.apache.spark.sql.hive.client.HiveClientImpl` and `org.apache.hadoop.hive.ql.session.SessionState` keep increasing even though we closed all the beeline connections, which may caused the leak of Memory. # How was this patch tested? manual tests This PR follw-up the https://github.com/apache/spark/pull/19989 Author: zuotingbing Closes #20029 from zuotingbing/SPARK-22793. (cherry picked from commit be9a804f2ef77a5044d3da7d9374976daf59fc16) Signed-off-by: gatorsmile --- .../org/apache/spark/sql/hive/HiveSessionStateBuilder.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala index 92cb4ef11c9e3..dc92ad3b0c1ac 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala @@ -42,7 +42,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session * Create a Hive aware resource loader. */ override protected lazy val resourceLoader: HiveSessionResourceLoader = { - val client: HiveClient = externalCatalog.client.newSession() + val client: HiveClient = externalCatalog.client new HiveSessionResourceLoader(session, client) } From 0377755985897c850dcf3a938520e5c10de4040a Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Sat, 6 Jan 2018 18:19:57 +0800 Subject: [PATCH 44/55] [SPARK-21786][SQL] When acquiring 'compressionCodecClassName' in 'ParquetOptions', `parquet.compression` needs to be considered. [SPARK-21786][SQL] When acquiring 'compressionCodecClassName' in 'ParquetOptions', `parquet.compression` needs to be considered. ## What changes were proposed in this pull request? Since Hive 1.1, Hive allows users to set parquet compression codec via table-level properties parquet.compression. See the JIRA: https://issues.apache.org/jira/browse/HIVE-7858 . We do support orc.compression for ORC. Thus, for external users, it is more straightforward to support both. See the stackflow question: https://stackoverflow.com/questions/36941122/spark-sql-ignores-parquet-compression-propertie-specified-in-tblproperties In Spark side, our table-level compression conf compression was added by #11464 since Spark 2.0. We need to support both table-level conf. Users might also use session-level conf spark.sql.parquet.compression.codec. The priority rule will be like If other compression codec configuration was found through hive or parquet, the precedence would be compression, parquet.compression, spark.sql.parquet.compression.codec. Acceptable values include: none, uncompressed, snappy, gzip, lzo. The rule for Parquet is consistent with the ORC after the change. Changes: 1.Increased acquiring 'compressionCodecClassName' from `parquet.compression`,and the precedence order is `compression`,`parquet.compression`,`spark.sql.parquet.compression.codec`, just like what we do in `OrcOptions`. 2.Change `spark.sql.parquet.compression.codec` to support "none".Actually in `ParquetOptions`,we do support "none" as equivalent to "uncompressed", but it does not allowed to configured to "none". 3.Change `compressionCode` to `compressionCodecClassName`. ## How was this patch tested? Add test. Author: fjh100456 Closes #20076 from fjh100456/ParquetOptionIssue. (cherry picked from commit 7b78041423b6ee330def2336dfd1ff9ae8469c59) Signed-off-by: gatorsmile --- docs/sql-programming-guide.md | 6 +- .../apache/spark/sql/internal/SQLConf.scala | 14 +- .../datasources/parquet/ParquetOptions.scala | 12 +- ...rquetCompressionCodecPrecedenceSuite.scala | 122 ++++++++++++++++++ 4 files changed, 145 insertions(+), 9 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala diff --git a/docs/sql-programming-guide.md b/docs/sql-programming-guide.md index b50f9360b866c..3ccaaf4d5b1fa 100644 --- a/docs/sql-programming-guide.md +++ b/docs/sql-programming-guide.md @@ -953,8 +953,10 @@ Configuration of Parquet can be done using the `setConf` method on `SparkSession spark.sql.parquet.compression.codec snappy - Sets the compression codec use when writing Parquet files. Acceptable values include: - uncompressed, snappy, gzip, lzo. + Sets the compression codec used when writing Parquet files. If either `compression` or + `parquet.compression` is specified in the table-specific options/properties, the precedence would be + `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`. Acceptable values include: + none, uncompressed, snappy, gzip, lzo. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 80b8965e084a2..7d1217de254a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -325,11 +325,13 @@ object SQLConf { .createWithDefault(false) val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") - .doc("Sets the compression codec use when writing Parquet files. Acceptable values include: " + - "uncompressed, snappy, gzip, lzo.") + .doc("Sets the compression codec used when writing Parquet files. If either `compression` or" + + "`parquet.compression` is specified in the table-specific options/properties, the precedence" + + "would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`." + + "Acceptable values include: none, uncompressed, snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) - .checkValues(Set("uncompressed", "snappy", "gzip", "lzo")) + .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) .createWithDefault("snappy") val PARQUET_FILTER_PUSHDOWN_ENABLED = buildConf("spark.sql.parquet.filterPushdown") @@ -366,8 +368,10 @@ object SQLConf { .createWithDefault(true) val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") - .doc("Sets the compression codec use when writing ORC files. Acceptable values include: " + - "none, uncompressed, snappy, zlib, lzo.") + .doc("Sets the compression codec used when writing ORC files. If either `compression` or" + + "`orc.compress` is specified in the table-specific options/properties, the precedence" + + "would be `compression`, `orc.compress`, `spark.sql.orc.compression.codec`." + + "Acceptable values include: none, uncompressed, snappy, zlib, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("none", "uncompressed", "snappy", "zlib", "lzo")) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 772d4565de548..ef67ea7d17cea 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -19,6 +19,7 @@ package org.apache.spark.sql.execution.datasources.parquet import java.util.Locale +import org.apache.parquet.hadoop.ParquetOutputFormat import org.apache.parquet.hadoop.metadata.CompressionCodecName import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap @@ -42,8 +43,15 @@ private[parquet] class ParquetOptions( * Acceptable values are defined in [[shortParquetCompressionCodecNames]]. */ val compressionCodecClassName: String = { - val codecName = parameters.getOrElse("compression", - sqlConf.parquetCompressionCodec).toLowerCase(Locale.ROOT) + // `compression`, `parquet.compression`(i.e., ParquetOutputFormat.COMPRESSION), and + // `spark.sql.parquet.compression.codec` + // are in order of precedence from highest to lowest. + val parquetCompressionConf = parameters.get(ParquetOutputFormat.COMPRESSION) + val codecName = parameters + .get("compression") + .orElse(parquetCompressionConf) + .getOrElse(sqlConf.parquetCompressionCodec) + .toLowerCase(Locale.ROOT) if (!shortParquetCompressionCodecNames.contains(codecName)) { val availableCodecs = shortParquetCompressionCodecNames.keys.map(_.toLowerCase(Locale.ROOT)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala new file mode 100644 index 0000000000000..ed8fd2b453456 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompressionCodecPrecedenceSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.io.File + +import scala.collection.JavaConverters._ + +import org.apache.hadoop.fs.Path +import org.apache.parquet.hadoop.ParquetOutputFormat + +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.test.SharedSQLContext + +class ParquetCompressionCodecPrecedenceSuite extends ParquetTest with SharedSQLContext { + test("Test `spark.sql.parquet.compression.codec` config") { + Seq("NONE", "UNCOMPRESSED", "SNAPPY", "GZIP", "LZO").foreach { c => + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> c) { + val expected = if (c == "NONE") "UNCOMPRESSED" else c + val option = new ParquetOptions(Map.empty[String, String], spark.sessionState.conf) + assert(option.compressionCodecClassName == expected) + } + } + } + + test("[SPARK-21786] Test Acquiring 'compressionCodecClassName' for parquet in right order.") { + // When "compression" is configured, it should be the first choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map("compression" -> "uncompressed", ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "UNCOMPRESSED") + } + + // When "compression" is not configured, "parquet.compression" should be the preferred choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map(ParquetOutputFormat.COMPRESSION -> "gzip") + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "GZIP") + } + + // When both "compression" and "parquet.compression" are not configured, + // spark.sql.parquet.compression.codec should be the right choice. + withSQLConf(SQLConf.PARQUET_COMPRESSION.key -> "snappy") { + val props = Map.empty[String, String] + val option = new ParquetOptions(props, spark.sessionState.conf) + assert(option.compressionCodecClassName == "SNAPPY") + } + } + + private def getTableCompressionCodec(path: String): Seq[String] = { + val hadoopConf = spark.sessionState.newHadoopConf() + val codecs = for { + footer <- readAllFootersWithoutSummaryFiles(new Path(path), hadoopConf) + block <- footer.getParquetMetadata.getBlocks.asScala + column <- block.getColumns.asScala + } yield column.getCodec.name() + codecs.distinct + } + + private def createTableWithCompression( + tableName: String, + isPartitioned: Boolean, + compressionCodec: String, + rootDir: File): Unit = { + val options = + s""" + |OPTIONS('path'='${rootDir.toURI.toString.stripSuffix("/")}/$tableName', + |'parquet.compression'='$compressionCodec') + """.stripMargin + val partitionCreate = if (isPartitioned) "PARTITIONED BY (p)" else "" + sql( + s""" + |CREATE TABLE $tableName USING Parquet $options $partitionCreate + |AS SELECT 1 AS col1, 2 AS p + """.stripMargin) + } + + private def checkCompressionCodec(compressionCodec: String, isPartitioned: Boolean): Unit = { + withTempDir { tmpDir => + val tempTableName = "TempParquetTable" + withTable(tempTableName) { + createTableWithCompression(tempTableName, isPartitioned, compressionCodec, tmpDir) + val partitionPath = if (isPartitioned) "p=2" else "" + val path = s"${tmpDir.getPath.stripSuffix("/")}/$tempTableName/$partitionPath" + val realCompressionCodecs = getTableCompressionCodec(path) + assert(realCompressionCodecs.forall(_ == compressionCodec)) + } + } + } + + test("Create parquet table with compression") { + Seq(true, false).foreach { isPartitioned => + Seq("UNCOMPRESSED", "SNAPPY", "GZIP").foreach { compressionCodec => + checkCompressionCodec(compressionCodec, isPartitioned) + } + } + } + + test("Create table with unknown compression") { + Seq(true, false).foreach { isPartitioned => + val exception = intercept[IllegalArgumentException] { + checkCompressionCodec("aa", isPartitioned) + } + assert(exception.getMessage.contains("Codec [aa] is not available")) + } + } +} From b66700a5ed9a6e31433bec7961361c382a8b4162 Mon Sep 17 00:00:00 2001 From: hyukjinkwon Date: Sat, 6 Jan 2018 23:08:26 +0800 Subject: [PATCH 45/55] [SPARK-22901][PYTHON][FOLLOWUP] Adds the doc for asNondeterministic for wrapped UDF function ## What changes were proposed in this pull request? This PR wraps the `asNondeterministic` attribute in the wrapped UDF function to set the docstring properly. ```python from pyspark.sql.functions import udf help(udf(lambda x: x).asNondeterministic) ``` Before: ``` Help on function in module pyspark.sql.udf: lambda (END ``` After: ``` Help on function asNondeterministic in module pyspark.sql.udf: asNondeterministic() Updates UserDefinedFunction to nondeterministic. .. versionadded:: 2.3 (END) ``` ## How was this patch tested? Manually tested and a simple test was added. Author: hyukjinkwon Closes #20173 from HyukjinKwon/SPARK-22901-followup. --- python/pyspark/sql/tests.py | 1 + python/pyspark/sql/udf.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 689736d8e6456..122a65b83aef9 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -413,6 +413,7 @@ def test_nondeterministic_udf2(self): pydoc.render_doc(udf(lambda: random.randint(6, 6), IntegerType())) pydoc.render_doc(random_udf) pydoc.render_doc(random_udf1) + pydoc.render_doc(udf(lambda x: x).asNondeterministic) def test_nondeterministic_udf_in_aggregate(self): from pyspark.sql.functions import udf, sum diff --git a/python/pyspark/sql/udf.py b/python/pyspark/sql/udf.py index 5e75eb6545333..5e80ab9165867 100644 --- a/python/pyspark/sql/udf.py +++ b/python/pyspark/sql/udf.py @@ -169,8 +169,8 @@ def wrapper(*args): wrapper.returnType = self.returnType wrapper.evalType = self.evalType wrapper.deterministic = self.deterministic - wrapper.asNondeterministic = lambda: self.asNondeterministic()._wrapped() - + wrapper.asNondeterministic = functools.wraps( + self.asNondeterministic)(lambda: self.asNondeterministic()._wrapped()) return wrapper def asNondeterministic(self): From f9e7b0c8aa9334fa2e467b26516ac8e54a51dc63 Mon Sep 17 00:00:00 2001 From: gatorsmile Date: Sun, 7 Jan 2018 00:19:21 +0800 Subject: [PATCH 46/55] [HOTFIX] Fix style checking failure ## What changes were proposed in this pull request? This PR is to fix the style checking failure. ## How was this patch tested? N/A Author: gatorsmile Closes #20175 from gatorsmile/stylefix. (cherry picked from commit 9a7048b2889bd0fd66e68a0ce3e07e466315a051) Signed-off-by: gatorsmile --- .../org/apache/spark/sql/internal/SQLConf.scala | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 7d1217de254a2..5c61f10bb71ad 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -325,10 +325,11 @@ object SQLConf { .createWithDefault(false) val PARQUET_COMPRESSION = buildConf("spark.sql.parquet.compression.codec") - .doc("Sets the compression codec used when writing Parquet files. If either `compression` or" + - "`parquet.compression` is specified in the table-specific options/properties, the precedence" + - "would be `compression`, `parquet.compression`, `spark.sql.parquet.compression.codec`." + - "Acceptable values include: none, uncompressed, snappy, gzip, lzo.") + .doc("Sets the compression codec used when writing Parquet files. If either `compression` or " + + "`parquet.compression` is specified in the table-specific options/properties, the " + + "precedence would be `compression`, `parquet.compression`, " + + "`spark.sql.parquet.compression.codec`. Acceptable values include: none, uncompressed, " + + "snappy, gzip, lzo.") .stringConf .transform(_.toLowerCase(Locale.ROOT)) .checkValues(Set("none", "uncompressed", "snappy", "gzip", "lzo")) @@ -368,8 +369,8 @@ object SQLConf { .createWithDefault(true) val ORC_COMPRESSION = buildConf("spark.sql.orc.compression.codec") - .doc("Sets the compression codec used when writing ORC files. If either `compression` or" + - "`orc.compress` is specified in the table-specific options/properties, the precedence" + + .doc("Sets the compression codec used when writing ORC files. If either `compression` or " + + "`orc.compress` is specified in the table-specific options/properties, the precedence " + "would be `compression`, `orc.compress`, `spark.sql.orc.compression.codec`." + "Acceptable values include: none, uncompressed, snappy, zlib, lzo.") .stringConf From 285d342c406cf304931844665e56725ef1a848e0 Mon Sep 17 00:00:00 2001 From: Takeshi Yamamuro Date: Sun, 7 Jan 2018 13:42:01 +0800 Subject: [PATCH 47/55] [SPARK-22973][SQL] Fix incorrect results of Casting Map to String ## What changes were proposed in this pull request? This pr fixed the issue when casting maps into strings; ``` scala> Seq(Map(1 -> "a", 2 -> "b")).toDF("a").write.saveAsTable("t") scala> sql("SELECT cast(a as String) FROM t").show(false) +----------------------------------------------------------------+ |a | +----------------------------------------------------------------+ |org.apache.spark.sql.catalyst.expressions.UnsafeMapData38bdd75d| +----------------------------------------------------------------+ ``` This pr modified the result into; ``` +----------------+ |a | +----------------+ |[1 -> a, 2 -> b]| +----------------+ ``` ## How was this patch tested? Added tests in `CastSuite`. Author: Takeshi Yamamuro Closes #20166 from maropu/SPARK-22973. (cherry picked from commit 18e94149992618a2b4e6f0fd3b3f4594e1745224) Signed-off-by: Wenchen Fan --- .../spark/sql/catalyst/expressions/Cast.scala | 89 +++++++++++++++++++ .../sql/catalyst/expressions/CastSuite.scala | 28 ++++++ 2 files changed, 117 insertions(+) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index d4fc5e0f168a7..f2de4c8e30bec 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -228,6 +228,37 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String builder.append("]") builder.build() }) + case MapType(kt, vt, _) => + buildCast[MapData](_, map => { + val builder = new UTF8StringBuilder + builder.append("[") + if (map.numElements > 0) { + val keyArray = map.keyArray() + val valueArray = map.valueArray() + val keyToUTF8String = castToString(kt) + val valueToUTF8String = castToString(vt) + builder.append(keyToUTF8String(keyArray.get(0, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (!valueArray.isNullAt(0)) { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(0, vt)).asInstanceOf[UTF8String]) + } + var i = 1 + while (i < map.numElements) { + builder.append(", ") + builder.append(keyToUTF8String(keyArray.get(i, kt)).asInstanceOf[UTF8String]) + builder.append(" ->") + if (!valueArray.isNullAt(i)) { + builder.append(" ") + builder.append(valueToUTF8String(valueArray.get(i, vt)) + .asInstanceOf[UTF8String]) + } + i += 1 + } + } + builder.append("]") + builder.build() + }) case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString)) } @@ -654,6 +685,53 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String """.stripMargin } + private def writeMapToStringBuilder( + kt: DataType, + vt: DataType, + map: String, + buffer: String, + ctx: CodegenContext): String = { + + def dataToStringFunc(func: String, dataType: DataType) = { + val funcName = ctx.freshName(func) + val dataToStringCode = castToStringCode(dataType, ctx) + ctx.addNewFunction(funcName, + s""" + |private UTF8String $funcName(${ctx.javaType(dataType)} data) { + | UTF8String dataStr = null; + | ${dataToStringCode("data", "dataStr", null /* resultIsNull won't be used */)} + | return dataStr; + |} + """.stripMargin) + } + + val keyToStringFunc = dataToStringFunc("keyToString", kt) + val valueToStringFunc = dataToStringFunc("valueToString", vt) + val loopIndex = ctx.freshName("loopIndex") + s""" + |$buffer.append("["); + |if ($map.numElements() > 0) { + | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, "0")})); + | $buffer.append(" ->"); + | if (!$map.valueArray().isNullAt(0)) { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc(${ctx.getValue(s"$map.valueArray()", vt, "0")})); + | } + | for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) { + | $buffer.append(", "); + | $buffer.append($keyToStringFunc(${ctx.getValue(s"$map.keyArray()", kt, loopIndex)})); + | $buffer.append(" ->"); + | if (!$map.valueArray().isNullAt($loopIndex)) { + | $buffer.append(" "); + | $buffer.append($valueToStringFunc( + | ${ctx.getValue(s"$map.valueArray()", vt, loopIndex)})); + | } + | } + |} + |$buffer.append("]"); + """.stripMargin + } + private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = { from match { case BinaryType => @@ -676,6 +754,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String |$evPrim = $buffer.build(); """.stripMargin } + case MapType(kt, vt, _) => + (c, evPrim, evNull) => { + val buffer = ctx.freshName("buffer") + val bufferClass = classOf[UTF8StringBuilder].getName + val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx) + s""" + |$bufferClass $buffer = new $bufferClass(); + |$writeMapElemCode; + |$evPrim = $buffer.build(); + """.stripMargin + } case _ => (c, evPrim, evNull) => s"$evPrim = UTF8String.fromString(String.valueOf($c));" } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala index e3ed7171defd8..1445bb8a97d40 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CastSuite.scala @@ -878,4 +878,32 @@ class CastSuite extends SparkFunSuite with ExpressionEvalHelper { StringType) checkEvaluation(ret8, "[[[a], [b, c]], [[d]]]") } + + test("SPARK-22973 Cast map to string") { + val ret1 = cast(Literal.create(Map(1 -> "a", 2 -> "b", 3 -> "c")), StringType) + checkEvaluation(ret1, "[1 -> a, 2 -> b, 3 -> c]") + val ret2 = cast( + Literal.create(Map("1" -> "a".getBytes, "2" -> null, "3" -> "c".getBytes)), + StringType) + checkEvaluation(ret2, "[1 -> a, 2 ->, 3 -> c]") + val ret3 = cast( + Literal.create(Map( + 1 -> Date.valueOf("2014-12-03"), + 2 -> Date.valueOf("2014-12-04"), + 3 -> Date.valueOf("2014-12-05"))), + StringType) + checkEvaluation(ret3, "[1 -> 2014-12-03, 2 -> 2014-12-04, 3 -> 2014-12-05]") + val ret4 = cast( + Literal.create(Map( + 1 -> Timestamp.valueOf("2014-12-03 13:01:00"), + 2 -> Timestamp.valueOf("2014-12-04 15:05:00"))), + StringType) + checkEvaluation(ret4, "[1 -> 2014-12-03 13:01:00, 2 -> 2014-12-04 15:05:00]") + val ret5 = cast( + Literal.create(Map( + 1 -> Array(1, 2, 3), + 2 -> Array(4, 5, 6))), + StringType) + checkEvaluation(ret5, "[1 -> [1, 2, 3], 2 -> [4, 5, 6]]") + } } From 5b150bc643e301ec063704303aebeea87de02e83 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Mon, 8 Jan 2018 14:18:24 +0800 Subject: [PATCH 48/55] Fix test issue Fix test issue --- .../sql/hive/CompressionCodecSuite.scala | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index 71aa95da17e3e..77689eb7c5d63 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -102,24 +102,24 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo val partitionCreate = if (isPartitioned) "PARTITIONED BY (p string)" else "" sql( s""" - |CREATE TABLE $tableName(a int) - |$partitionCreate - |STORED AS $format - |LOCATION '${rootDir.toURI.toString.stripSuffix("/")}/$tableName' - |$tblProperties - """.stripMargin) + |CREATE TABLE $tableName(a int) + |$partitionCreate + |STORED AS $format + |LOCATION '${rootDir.toURI.toString.stripSuffix("/")}/$tableName' + |$tblProperties + """.stripMargin) } private def writeDataToTable( tableName: String, partition: Option[String]): Unit = { - val partitionInsert = partition.map(p => s"partition ($p)").mkString + val partitionInsert = partition.map(p => s"partition (p='$p')").mkString sql( s""" - |INSERT INTO TABLE $tableName - |$partitionInsert - |SELECT * FROM table_source - """.stripMargin) + |INSERT INTO TABLE $tableName + |$partitionInsert + |SELECT * FROM table_source + """.stripMargin) } private def getTableSize(path: String): Long = { @@ -128,6 +128,11 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo files.map(_.length()).sum } + private def getTablePartitionPath(dir: File, tableName: String, partition: Option[String]) = { + val partitionPath = partition.map(p => s"p=$p").mkString + s"${dir.getPath.stripSuffix("/")}/$tableName/$partitionPath" + } + private def getUncompressedDataSizeByFormat( format: String, isPartitioned: Boolean): Long = { var totalSize = 0L @@ -137,9 +142,9 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo withTempDir { tmpDir => withTable(tableName) { createTable(tmpDir, tableName, isPartitioned, format, Option(codecName)) - val partition = if (isPartitioned) Some("p='test'") else None + val partition = if (isPartitioned) Some("test") else None writeDataToTable(tableName, partition) - val path = s"${tmpDir.getPath.stripSuffix("/")}/$tableName/${partition.mkString}" + val path = getTablePartitionPath(tmpDir, tableName, partition) totalSize = getTableSize(path) } } @@ -157,9 +162,9 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo withTempDir { tmpDir => withTable(tableName) { createTable(tmpDir, tableName, isPartitioned, format, compressionCodec) - val partition = if (isPartitioned) Some("p='test'") else None + val partition = if (isPartitioned) Some("test") else None writeDataToTable(tableName, partition) - val path = s"${tmpDir.getPath.stripSuffix("/")}/$tableName/${partition.mkString}" + val path = getTablePartitionPath(tmpDir, tableName, partition) val relCompressionCodecs = getTableCompressionCodec(path, format) assert(relCompressionCodecs.length == 1) val tableSize = getTableSize(path) @@ -289,12 +294,12 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo createTable(tmpDir, tableName, isPartitioned, format, None) withTable(tableName) { compressCodecs.foreach { compressionCodec => - val partition = if (isPartitioned) Some(s"p='$compressionCodec'") else None + val partition = if (isPartitioned) Some(compressionCodec) else None withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString, getSparkCompressionConfName(format) -> compressionCodec ) { writeDataToTable(tableName, partition) } } - val tablePath = s"${tmpDir.getPath.stripSuffix("/")}/$tableName" + val tablePath = getTablePartitionPath(tmpDir, tableName, None) val relCompressionCodecs = if (isPartitioned) compressCodecs.flatMap { codec => getTableCompressionCodec(s"$tablePath/p=$codec", format) From 4b89b44b2b4a9a1ffee01da13a3e55cdd3aa260d Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Thu, 11 Jan 2018 10:28:42 +0800 Subject: [PATCH 49/55] consider the precedence of `hive.exec.compress.output` --- .../sql/hive/execution/HiveOptions.scala | 2 +- .../sql/hive/execution/SaveAsHiveFile.scala | 24 ++++++++++++------- 2 files changed, 16 insertions(+), 10 deletions(-) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala index 192f6bbebc7d1..802ddafdbee4d 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveOptions.scala @@ -114,7 +114,7 @@ object HiveOptions { def getHiveWriteCompression(tableInfo: TableDesc, sqlConf: SQLConf): Option[(String, String)] = { val tableProps = tableInfo.getProperties.asScala.toMap - tableInfo.getOutputFileFormatClassName.toLowerCase match { + tableInfo.getOutputFileFormatClassName.toLowerCase(Locale.ROOT) match { case formatName if formatName.endsWith("parquetoutputformat") => val compressionCodec = new ParquetOptions(tableProps, sqlConf).compressionCodecClassName Option((ParquetOutputFormat.COMPRESSION, compressionCodec)) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala index 4ba486ee7ff53..e484356906e87 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/SaveAsHiveFile.scala @@ -55,24 +55,30 @@ private[hive] trait SaveAsHiveFile extends DataWritingCommand { customPartitionLocations: Map[TablePartitionSpec, String] = Map.empty, partitionAttributes: Seq[Attribute] = Nil): Set[String] = { - val isCompressed = hadoopConf.get("hive.exec.compress.output", "false").toBoolean + val isCompressed = + fileSinkConf.getTableInfo.getOutputFileFormatClassName.toLowerCase(Locale.ROOT) match { + case formatName if formatName.endsWith("orcoutputformat") => + // For ORC,"mapreduce.output.fileoutputformat.compress", + // "mapreduce.output.fileoutputformat.compress.codec", and + // "mapreduce.output.fileoutputformat.compress.type" + // have no impact because it uses table properties to store compression information. + false + case _ => hadoopConf.get("hive.exec.compress.output", "false").toBoolean + } + if (isCompressed) { - // Please note that isCompressed, "mapreduce.output.fileoutputformat.compress", - // "mapreduce.output.fileoutputformat.compress.codec", and - // "mapreduce.output.fileoutputformat.compress.type" - // have no impact on ORC because it uses table properties to store compression information. hadoopConf.set("mapreduce.output.fileoutputformat.compress", "true") fileSinkConf.setCompressed(true) fileSinkConf.setCompressCodec(hadoopConf .get("mapreduce.output.fileoutputformat.compress.codec")) fileSinkConf.setCompressType(hadoopConf .get("mapreduce.output.fileoutputformat.compress.type")) + } else { + // Set compression by priority + HiveOptions.getHiveWriteCompression(fileSinkConf.getTableInfo, sparkSession.sessionState.conf) + .foreach { case (compression, codec) => hadoopConf.set(compression, codec) } } - // Set compression by priority - HiveOptions.getHiveWriteCompression(fileSinkConf.getTableInfo, sparkSession.sessionState.conf) - .foreach { case (compression, codec) => hadoopConf.set(compression, codec) } - val committer = FileCommitProtocol.instantiate( sparkSession.sessionState.conf.fileCommitProtocolClass, jobId = java.util.UUID.randomUUID().toString, From 6cf32e075043eb87dabcae8b88c242ea8f211d8d Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Fri, 19 Jan 2018 17:59:49 +0800 Subject: [PATCH 50/55] Resume to private and add public function --- .../spark/sql/execution/datasources/orc/OrcOptions.scala | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala index faa1f4a4943f7..0ad3862f6cf01 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcOptions.scala @@ -61,10 +61,12 @@ class OrcOptions( object OrcOptions { // The ORC compression short names - val shortOrcCompressionCodecNames = Map( + private val shortOrcCompressionCodecNames = Map( "none" -> "NONE", "uncompressed" -> "NONE", "snappy" -> "SNAPPY", "zlib" -> "ZLIB", "lzo" -> "LZO") + + def getORCCompressionCodecName(name: String): String = shortOrcCompressionCodecNames(name) } From 365c5bf81abcc5643ebb7908b514ec373bedc5d5 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Fri, 19 Jan 2018 18:00:28 +0800 Subject: [PATCH 51/55] Resume to private and add public function --- .../sql/execution/datasources/parquet/ParquetOptions.scala | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index 383599b7fce3d..cb88de9aee57e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -76,10 +76,13 @@ object ParquetOptions { val MERGE_SCHEMA = "mergeSchema" // The parquet compression short names - val shortParquetCompressionCodecNames = Map( + private val shortParquetCompressionCodecNames = Map( "none" -> CompressionCodecName.UNCOMPRESSED, "uncompressed" -> CompressionCodecName.UNCOMPRESSED, "snappy" -> CompressionCodecName.SNAPPY, "gzip" -> CompressionCodecName.GZIP, "lzo" -> CompressionCodecName.LZO) + + def getParquetCompressionCodecName(name: String): String = + shortParquetCompressionCodecNames(name).name() } From 99271d670a0aed444ad624d56304d94490eed0cb Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Fri, 19 Jan 2018 18:01:27 +0800 Subject: [PATCH 52/55] Fix test issue --- .../sql/hive/CompressionCodecSuite.scala | 184 ++++++++++-------- 1 file changed, 106 insertions(+), 78 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index 77689eb7c5d63..4d72e61ff2fd5 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -48,7 +48,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo } } - private val maxRecordNum = 500 + private val maxRecordNum = 50 private def getConvertMetastoreConfName(format: String): String = format.toLowerCase match { case "parquet" => HiveUtils.CONVERT_METASTORE_PARQUET.key @@ -67,8 +67,8 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo private def normalizeCodecName(format: String, name: String): String = { format.toLowerCase match { - case "parquet" => ParquetOptions.shortParquetCompressionCodecNames(name).name() - case "orc" => OrcOptions.shortOrcCompressionCodecNames(name) + case "parquet" => ParquetOptions.getParquetCompressionCodecName(name) + case "orc" => OrcOptions.getORCCompressionCodecName(name) } } @@ -80,7 +80,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo block <- footer.getParquetMetadata.getBlocks.asScala column <- block.getColumns.asScala } yield column.getCodec.name() - case "orc" => new File(path).listFiles().filter{ file => + case "orc" => new File(path).listFiles().filter { file => file.isFile && !file.getName.endsWith(".crc") && file.getName != "_SUCCESS" }.map { orcFile => OrcFileOperator.getFileReader(orcFile.toPath.toString).get.getCompression.toString @@ -112,8 +112,8 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo private def writeDataToTable( tableName: String, - partition: Option[String]): Unit = { - val partitionInsert = partition.map(p => s"partition (p='$p')").mkString + partitionValue: Option[String]): Unit = { + val partitionInsert = partitionValue.map(p => s"partition (p='$p')").mkString sql( s""" |INSERT INTO TABLE $tableName @@ -122,29 +122,69 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo """.stripMargin) } + private def writeDateToTableUsingCTAS( + rootDir: File, + tableName: String, + partitionValue: Option[String], + format: String, + compressionCodec: Option[String]): Unit = { + val partitionCreate = partitionValue.map(p => s"PARTITIONED BY (p)").mkString + val compressionOption = compressionCodec.map { codec => + s",'${getHiveCompressPropName(format)}'='$codec'" + }.mkString + val partitionSelect = partitionValue.map(p => s",'$p' AS p").mkString + sql( + s""" + |CREATE TABLE $tableName + |USING $format + |OPTIONS('path'='${rootDir.toURI.toString.stripSuffix("/")}/$tableName' $compressionOption) + |$partitionCreate + |AS SELECT * $partitionSelect FROM table_source + """.stripMargin) + } + + private def getPreparedTablePath( + tmpDir: File, + tableName: String, + isPartitioned: Boolean, + format: String, + compressionCodec: Option[String], + usingCTAS: Boolean): String = { + val partitionValue = if (isPartitioned) Some("test") else None + if (usingCTAS) { + writeDateToTableUsingCTAS(tmpDir, tableName, partitionValue, format, compressionCodec) + } else { + createTable(tmpDir, tableName, isPartitioned, format, compressionCodec) + writeDataToTable(tableName, partitionValue) + } + getTablePartitionPath(tmpDir, tableName, partitionValue) + } + private def getTableSize(path: String): Long = { val dir = new File(path) val files = dir.listFiles().filter(_.getName.startsWith("part-")) files.map(_.length()).sum } - private def getTablePartitionPath(dir: File, tableName: String, partition: Option[String]) = { - val partitionPath = partition.map(p => s"p=$p").mkString + private def getTablePartitionPath( + dir: File, + tableName: String, + partitionValue: Option[String]) = { + val partitionPath = partitionValue.map(p => s"p=$p").mkString s"${dir.getPath.stripSuffix("/")}/$tableName/$partitionPath" } private def getUncompressedDataSizeByFormat( - format: String, isPartitioned: Boolean): Long = { + format: String, isPartitioned: Boolean, usingCTAS: Boolean): Long = { var totalSize = 0L val tableName = s"tbl_$format" val codecName = normalizeCodecName(format, "uncompressed") withSQLConf(getSparkCompressionConfName(format) -> codecName) { withTempDir { tmpDir => withTable(tableName) { - createTable(tmpDir, tableName, isPartitioned, format, Option(codecName)) - val partition = if (isPartitioned) Some("test") else None - writeDataToTable(tableName, partition) - val path = getTablePartitionPath(tmpDir, tableName, partition) + val compressionCodec = Option(codecName) + val path = getPreparedTablePath( + tmpDir, tableName, isPartitioned, format, compressionCodec, usingCTAS) totalSize = getTableSize(path) } } @@ -156,15 +196,15 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo private def checkCompressionCodecForTable( format: String, isPartitioned: Boolean, - compressionCodec: Option[String]) + compressionCodec: Option[String], + usingCTAS: Boolean) (assertion: (String, Long) => Unit): Unit = { - val tableName = s"tbl_$format$isPartitioned" + val tableName = + if (usingCTAS) s"tbl_$format$isPartitioned" else s"tbl_$format${isPartitioned}_CAST" withTempDir { tmpDir => withTable(tableName) { - createTable(tmpDir, tableName, isPartitioned, format, compressionCodec) - val partition = if (isPartitioned) Some("test") else None - writeDataToTable(tableName, partition) - val path = getTablePartitionPath(tmpDir, tableName, partition) + val path = getPreparedTablePath( + tmpDir, tableName, isPartitioned, format, compressionCodec, usingCTAS) val relCompressionCodecs = getTableCompressionCodec(path, format) assert(relCompressionCodecs.length == 1) val tableSize = getTableSize(path) @@ -177,6 +217,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo format: String, isPartitioned: Boolean, convertMetastore: Boolean, + usingCTAS: Boolean, compressionCodecs: List[String], tableCompressionCodecs: List[String]) (assertionCompressionCodec: (Option[String], String, String, Long) => Unit): Unit = { @@ -186,9 +227,10 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo withSQLConf(getSparkCompressionConfName(format) -> sessionCompressionCodec) { // 'tableCompression = null' means no table-level compression val compression = Option(tableCompression) - checkCompressionCodecForTable(format, isPartitioned, compression) { - case (realCompressionCodec, tableSize) => assertionCompressionCodec(compression, - sessionCompressionCodec, realCompressionCodec, tableSize) + checkCompressionCodecForTable(format, isPartitioned, compression, usingCTAS) { + case (realCompressionCodec, tableSize) => + assertionCompressionCodec( + compression, sessionCompressionCodec, realCompressionCodec, tableSize) } } } @@ -203,55 +245,35 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo compressionCodec: String, isPartitioned: Boolean, convertMetastore: Boolean, + usingCTAS: Boolean, tableSize: Long): Boolean = { - format match { - case "parquet" => - val uncompressedSize = if (!convertMetastore || isPartitioned) { - getUncompressedDataSizeByFormat(format, isPartitioned = true) - } else { - getUncompressedDataSizeByFormat(format, isPartitioned = false) - } - - if (compressionCodec == "UNCOMPRESSED") { - tableSize == uncompressedSize - } else { - tableSize != uncompressedSize - } - case "orc" => - val uncompressedSize = if (!convertMetastore || isPartitioned) { - getUncompressedDataSizeByFormat(format, isPartitioned = true) - } else { - getUncompressedDataSizeByFormat(format, isPartitioned = false) - } - - if (compressionCodec == "NONE") { - tableSize == uncompressedSize - } else { - tableSize != uncompressedSize - } - case _ => false + val uncompressedSize = getUncompressedDataSizeByFormat(format, isPartitioned, usingCTAS) + compressionCodec match { + case "UNCOMPRESSED" if format == "parquet" => tableSize == uncompressedSize + case "NONE" if format == "orc" => tableSize == uncompressedSize + case _ => tableSize != uncompressedSize } } def checkForTableWithCompressProp(format: String, compressCodecs: List[String]): Unit = { Seq(true, false).foreach { isPartitioned => Seq(true, false).foreach { convertMetastore => - checkTableCompressionCodecForCodecs( - format, - isPartitioned, - convertMetastore, - compressionCodecs = compressCodecs, - tableCompressionCodecs = compressCodecs) { - case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => - // For non-partitioned table and when convertMetastore is false, Expect session-level - // take effect, and in other cases expect table-level take effect - val expectCompressionCodec = - if (convertMetastore && !isPartitioned) sessionCompressionCodec - else tableCompressionCodec.get - - assert(expectCompressionCodec == realCompressionCodec) - assert(checkTableSize(format, expectCompressionCodec, - isPartitioned, convertMetastore, tableSize)) + Seq(true, false).foreach { usingCTAS => + checkTableCompressionCodecForCodecs( + format, + isPartitioned, + convertMetastore, + usingCTAS, + compressionCodecs = compressCodecs, + tableCompressionCodecs = compressCodecs) { + case + (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => + // After SPARK-22926, table-level will take effect + val expectCompressionCodec = tableCompressionCodec.get + assert(expectCompressionCodec == realCompressionCodec) + assert(checkTableSize(format, expectCompressionCodec, + isPartitioned, convertMetastore, usingCTAS, tableSize)) + } } } } @@ -260,17 +282,21 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo def checkForTableWithoutCompressProp(format: String, compressCodecs: List[String]): Unit = { Seq(true, false).foreach { isPartitioned => Seq(true, false).foreach { convertMetastore => - checkTableCompressionCodecForCodecs( - format, - isPartitioned, - convertMetastore, - compressionCodecs = compressCodecs, - tableCompressionCodecs = List(null)) { - case (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => - // Always expect session-level take effect - assert(sessionCompressionCodec == realCompressionCodec) - assert(checkTableSize(format, sessionCompressionCodec, - isPartitioned, convertMetastore, tableSize)) + Seq(true, false).foreach { usingCTAS => + checkTableCompressionCodecForCodecs( + format, + isPartitioned, + convertMetastore, + usingCTAS, + compressionCodecs = compressCodecs, + tableCompressionCodecs = List(null)) { + case + (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => + // Always expect session-level take effect + assert(sessionCompressionCodec == realCompressionCodec) + assert(checkTableSize(format, sessionCompressionCodec, + isPartitioned, convertMetastore, usingCTAS, tableSize)) + } } } } @@ -294,16 +320,18 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo createTable(tmpDir, tableName, isPartitioned, format, None) withTable(tableName) { compressCodecs.foreach { compressionCodec => - val partition = if (isPartitioned) Some(compressionCodec) else None + val partitionValue = if (isPartitioned) Some(compressionCodec) else None withSQLConf(getConvertMetastoreConfName(format) -> convertMetastore.toString, getSparkCompressionConfName(format) -> compressionCodec - ) { writeDataToTable(tableName, partition) } + ) { writeDataToTable(tableName, partitionValue) } } val tablePath = getTablePartitionPath(tmpDir, tableName, None) val relCompressionCodecs = if (isPartitioned) compressCodecs.flatMap { codec => getTableCompressionCodec(s"$tablePath/p=$codec", format) - } else getTableCompressionCodec(tablePath, format) + } else { + getTableCompressionCodec(tablePath, format) + } assert(relCompressionCodecs.distinct.sorted == compressCodecs.sorted) val recordsNum = sql(s"SELECT * from $tableName").count() From 2b9dfbe3ed6beb48e1b5b13a85a0e1e9b0fffc98 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Sat, 20 Jan 2018 13:10:37 +0800 Subject: [PATCH 53/55] Fix test issue --- .../sql/hive/CompressionCodecSuite.scala | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index 4d72e61ff2fd5..b51ccdf200a74 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -258,7 +258,8 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo def checkForTableWithCompressProp(format: String, compressCodecs: List[String]): Unit = { Seq(true, false).foreach { isPartitioned => Seq(true, false).foreach { convertMetastore => - Seq(true, false).foreach { usingCTAS => + // TODO: Also verify CTAS(usingCTAS=true) cases when the bug(SPARK-22926) is fixed. + Seq(false).foreach { usingCTAS => checkTableCompressionCodecForCodecs( format, isPartitioned, @@ -266,13 +267,16 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo usingCTAS, compressionCodecs = compressCodecs, tableCompressionCodecs = compressCodecs) { - case - (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => - // After SPARK-22926, table-level will take effect - val expectCompressionCodec = tableCompressionCodec.get - assert(expectCompressionCodec == realCompressionCodec) - assert(checkTableSize(format, expectCompressionCodec, - isPartitioned, convertMetastore, usingCTAS, tableSize)) + case (tableCodec, sessionCodec, realCodec, tableSize) => + // For non-partitioned table and when convertMetastore is true, Expect session-level + // take effect, and in other cases expect table-level take effect + // TODO: It should always be table-level taking effect when the bug(SPARK-22926) + // is fixed + val expectCodec = + if (convertMetastore && !isPartitioned) sessionCodec else tableCodec.get + assert(expectCodec == realCodec) + assert(checkTableSize( + format, expectCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) } } } @@ -282,7 +286,8 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo def checkForTableWithoutCompressProp(format: String, compressCodecs: List[String]): Unit = { Seq(true, false).foreach { isPartitioned => Seq(true, false).foreach { convertMetastore => - Seq(true, false).foreach { usingCTAS => + // TODO: Also verify CTAS(usingCTAS=true) cases when the bug(SPARK-22926) is fixed. + Seq(false).foreach { usingCTAS => checkTableCompressionCodecForCodecs( format, isPartitioned, @@ -291,11 +296,11 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo compressionCodecs = compressCodecs, tableCompressionCodecs = List(null)) { case - (tableCompressionCodec, sessionCompressionCodec, realCompressionCodec, tableSize) => + (tableCodec, sessionCodec, realCodec, tableSize) => // Always expect session-level take effect - assert(sessionCompressionCodec == realCompressionCodec) - assert(checkTableSize(format, sessionCompressionCodec, - isPartitioned, convertMetastore, usingCTAS, tableSize)) + assert(sessionCodec == realCodec) + assert(checkTableSize( + format, sessionCodec, isPartitioned, convertMetastore, usingCTAS, tableSize)) } } } @@ -326,14 +331,14 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo ) { writeDataToTable(tableName, partitionValue) } } val tablePath = getTablePartitionPath(tmpDir, tableName, None) - val relCompressionCodecs = + val realCompressionCodecs = if (isPartitioned) compressCodecs.flatMap { codec => getTableCompressionCodec(s"$tablePath/p=$codec", format) } else { getTableCompressionCodec(tablePath, format) } - assert(relCompressionCodecs.distinct.sorted == compressCodecs.sorted) + assert(realCompressionCodecs.distinct.sorted == compressCodecs.sorted) val recordsNum = sql(s"SELECT * from $tableName").count() assert(recordsNum == maxRecordNum * compressCodecs.length) } From 5b5e1df983af6ff03ec6ef6c83208c8b25af93e2 Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Sat, 20 Jan 2018 13:10:42 +0800 Subject: [PATCH 54/55] Fix style issue --- .../sql/execution/datasources/parquet/ParquetOptions.scala | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala index cb88de9aee57e..f36a89a4c3c5f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetOptions.scala @@ -83,6 +83,7 @@ object ParquetOptions { "gzip" -> CompressionCodecName.GZIP, "lzo" -> CompressionCodecName.LZO) - def getParquetCompressionCodecName(name: String): String = + def getParquetCompressionCodecName(name: String): String = { shortParquetCompressionCodecNames(name).name() + } } From 118f7880bdcf26ba7394a2cc7fac2e0eae707d6f Mon Sep 17 00:00:00 2001 From: fjh100456 Date: Sat, 20 Jan 2018 18:13:24 +0800 Subject: [PATCH 55/55] Fix style issue --- .../org/apache/spark/sql/hive/CompressionCodecSuite.scala | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala index b51ccdf200a74..d10a6f25c64fc 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CompressionCodecSuite.scala @@ -295,8 +295,7 @@ class CompressionCodecSuite extends TestHiveSingleton with ParquetTest with Befo usingCTAS, compressionCodecs = compressCodecs, tableCompressionCodecs = List(null)) { - case - (tableCodec, sessionCodec, realCodec, tableSize) => + case (tableCodec, sessionCodec, realCodec, tableSize) => // Always expect session-level take effect assert(sessionCodec == realCodec) assert(checkTableSize(