diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index e33d0d8e29d49..97e0c9edeab48 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -2642,6 +2642,7 @@ generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) { #' #' Return a new SparkDataFrame containing the union of rows in this SparkDataFrame #' and another SparkDataFrame. This is equivalent to \code{UNION ALL} in SQL. +#' Input SparkDataFrames can have different schemas (names and data types). #' #' Note: This does not remove duplicate rows across the two SparkDataFrames. #' @@ -2685,7 +2686,8 @@ setMethod("unionAll", #' Union two or more SparkDataFrames #' -#' Union two or more SparkDataFrames. This is equivalent to \code{UNION ALL} in SQL. +#' Union two or more SparkDataFrames by row. As in R's \code{rbind}, this method +#' requires that the input SparkDataFrames have the same column names. #' #' Note: This does not remove duplicate rows across the two SparkDataFrames. #' @@ -2709,6 +2711,10 @@ setMethod("unionAll", setMethod("rbind", signature(... = "SparkDataFrame"), function(x, ..., deparse.level = 1) { + nm <- lapply(list(x, ...), names) + if (length(unique(nm)) != 1) { + stop("Names of input data frames are different.") + } if (nargs() == 3) { union(x, ...) } else { diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 7c096597fea66..9735fe3201553 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1850,6 +1850,13 @@ test_that("union(), rbind(), except(), and intersect() on a DataFrame", { expect_equal(count(unioned2), 12) expect_equal(first(unioned2)$name, "Michael") + df3 <- df2 + names(df3)[1] <- "newName" + expect_error(rbind(df, df3), + "Names of input data frames are different.") + expect_error(rbind(df, df2, df3), + "Names of input data frames are different.") + excepted <- arrange(except(df, df2), desc(df$age)) expect_is(unioned, "SparkDataFrame") expect_equal(count(excepted), 2) @@ -2585,8 +2592,8 @@ test_that("coalesce, repartition, numPartitions", { df2 <- repartition(df1, 10) expect_equal(getNumPartitions(df2), 10) - expect_equal(getNumPartitions(coalesce(df2, 13)), 5) - expect_equal(getNumPartitions(coalesce(df2, 7)), 5) + expect_equal(getNumPartitions(coalesce(df2, 13)), 10) + expect_equal(getNumPartitions(coalesce(df2, 7)), 7) expect_equal(getNumPartitions(coalesce(df2, 3)), 3) }) diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 10a7cb1d06659..4c28075bd9386 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -850,11 +850,23 @@ public UTF8String translate(Map dict) { return fromString(sb.toString()); } - private int getDigit(byte b) { - if (b >= '0' && b <= '9') { - return b - '0'; - } - throw new NumberFormatException(toString()); + /** + * Wrapper over `long` to allow result of parsing long from string to be accessed via reference. + * This is done solely for better performance and is not expected to be used by end users. + */ + public static class LongWrapper { + public long value = 0; + } + + /** + * Wrapper over `int` to allow result of parsing integer from string to be accessed via reference. + * This is done solely for better performance and is not expected to be used by end users. 
+ * + * {@link LongWrapper} could have been used here but using `int` directly save the extra cost of + * conversion from `long` -> `int` + */ + public static class IntWrapper { + public int value = 0; } /** @@ -862,14 +874,18 @@ private int getDigit(byte b) { * * Note that, in this method we accumulate the result in negative format, and convert it to * positive format at the end, if this string is not started with '-'. This is because min value - * is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and - * Integer.MIN_VALUE is '-2147483648'. + * is bigger than max value in digits, e.g. Long.MAX_VALUE is '9223372036854775807' and + * Long.MIN_VALUE is '-9223372036854775808'. * * This code is mostly copied from LazyLong.parseLong in Hive. + * + * @param toLongResult If a valid `long` was parsed from this UTF8String, then its value would + * be set in `toLongResult` + * @return true if the parsing was successful else false */ - public long toLong() { + public boolean toLong(LongWrapper toLongResult) { if (numBytes == 0) { - throw new NumberFormatException("Empty string"); + return false; } byte b = getByte(0); @@ -878,7 +894,7 @@ public long toLong() { if (negative || b == '+') { offset++; if (numBytes == 1) { - throw new NumberFormatException(toString()); + return false; } } @@ -897,20 +913,25 @@ public long toLong() { break; } - int digit = getDigit(b); + int digit; + if (b >= '0' && b <= '9') { + digit = b - '0'; + } else { + return false; + } + // We are going to process the new digit and accumulate the result. However, before doing // this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then - // result * 10 will definitely be smaller than minValue, and we can stop and throw exception. + // result * 10 will definitely be smaller than minValue, and we can stop. if (result < stopValue) { - throw new NumberFormatException(toString()); + return false; } result = result * radix - digit; // Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we - // can just use `result > 0` to check overflow. If result overflows, we should stop and throw - // exception. + // can just use `result > 0` to check overflow. If result overflows, we should stop. if (result > 0) { - throw new NumberFormatException(toString()); + return false; } } @@ -918,8 +939,9 @@ public long toLong() { // part will not change the number, but we will verify that the fractional part // is well formed. while (offset < numBytes) { - if (getDigit(getByte(offset)) == -1) { - throw new NumberFormatException(toString()); + byte currentByte = getByte(offset); + if (currentByte < '0' || currentByte > '9') { + return false; } offset++; } @@ -927,11 +949,12 @@ public long toLong() { if (!negative) { result = -result; if (result < 0) { - throw new NumberFormatException(toString()); + return false; } } - return result; + toLongResult.value = result; + return true; } /** @@ -946,10 +969,14 @@ public long toLong() { * * Note that, this method is almost same as `toLong`, but we leave it duplicated for performance * reasons, like Hive does. 
+ * + * @param intWrapper If a valid `int` was parsed from this UTF8String, then its value would + * be set in `intWrapper` + * @return true if the parsing was successful else false */ - public int toInt() { + public boolean toInt(IntWrapper intWrapper) { if (numBytes == 0) { - throw new NumberFormatException("Empty string"); + return false; } byte b = getByte(0); @@ -958,7 +985,7 @@ public int toInt() { if (negative || b == '+') { offset++; if (numBytes == 1) { - throw new NumberFormatException(toString()); + return false; } } @@ -977,20 +1004,25 @@ public int toInt() { break; } - int digit = getDigit(b); + int digit; + if (b >= '0' && b <= '9') { + digit = b - '0'; + } else { + return false; + } + // We are going to process the new digit and accumulate the result. However, before doing // this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then - // result * 10 will definitely be smaller than minValue, and we can stop and throw exception. + // result * 10 will definitely be smaller than minValue, and we can stop if (result < stopValue) { - throw new NumberFormatException(toString()); + return false; } result = result * radix - digit; // Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix), - // we can just use `result > 0` to check overflow. If result overflows, we should stop and - // throw exception. + // we can just use `result > 0` to check overflow. If result overflows, we should stop if (result > 0) { - throw new NumberFormatException(toString()); + return false; } } @@ -998,8 +1030,9 @@ public int toInt() { // part will not change the number, but we will verify that the fractional part // is well formed. while (offset < numBytes) { - if (getDigit(getByte(offset)) == -1) { - throw new NumberFormatException(toString()); + byte currentByte = getByte(offset); + if (currentByte < '0' || currentByte > '9') { + return false; } offset++; } @@ -1007,31 +1040,33 @@ public int toInt() { if (!negative) { result = -result; if (result < 0) { - throw new NumberFormatException(toString()); + return false; } } - - return result; + intWrapper.value = result; + return true; } - public short toShort() { - int intValue = toInt(); - short result = (short) intValue; - if (result != intValue) { - throw new NumberFormatException(toString()); + public boolean toShort(IntWrapper intWrapper) { + if (toInt(intWrapper)) { + int intValue = intWrapper.value; + short result = (short) intValue; + if (result == intValue) { + return true; + } } - - return result; + return false; } - public byte toByte() { - int intValue = toInt(); - byte result = (byte) intValue; - if (result != intValue) { - throw new NumberFormatException(toString()); + public boolean toByte(IntWrapper intWrapper) { + if (toInt(intWrapper)) { + int intValue = intWrapper.value; + byte result = (byte) intValue; + if (result == intValue) { + return true; + } } - - return result; + return false; } @Override diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 6f6e0ef0e4855..c376371abdf90 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -22,9 +22,7 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.charset.StandardCharsets; -import java.util.Arrays; -import java.util.HashMap; -import java.util.HashSet; +import 
java.util.*; import com.google.common.collect.ImmutableMap; import org.apache.spark.unsafe.Platform; @@ -608,4 +606,128 @@ public void writeToOutputStreamIntArray() throws IOException { .writeTo(outputStream); assertEquals("大千世界", outputStream.toString("UTF-8")); } + + @Test + public void testToShort() throws IOException { + Map inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", (short) 1); + inputToExpectedOutput.put("+1", (short) 1); + inputToExpectedOutput.put("-1", (short) -1); + inputToExpectedOutput.put("0", (short) 0); + inputToExpectedOutput.put("1111.12345678901234567890", (short) 1111); + inputToExpectedOutput.put(String.valueOf(Short.MAX_VALUE), Short.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Short.MIN_VALUE), Short.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + short value = (short) rand.nextInt(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + IntWrapper wrapper = new IntWrapper(); + for (Map.Entry entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toShort(wrapper)); + assertEquals((short) entry.getValue(), wrapper.value); + } + + List negativeInputs = + Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "3276700"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toShort(wrapper)); + } + } + + @Test + public void testToByte() throws IOException { + Map inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", (byte) 1); + inputToExpectedOutput.put("+1",(byte) 1); + inputToExpectedOutput.put("-1", (byte) -1); + inputToExpectedOutput.put("0", (byte) 0); + inputToExpectedOutput.put("111.12345678901234567890", (byte) 111); + inputToExpectedOutput.put(String.valueOf(Byte.MAX_VALUE), Byte.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Byte.MIN_VALUE), Byte.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + byte value = (byte) rand.nextInt(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + IntWrapper intWrapper = new IntWrapper(); + for (Map.Entry entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toByte(intWrapper)); + assertEquals((byte) entry.getValue(), intWrapper.value); + } + + List negativeInputs = + Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toByte(intWrapper)); + } + } + + @Test + public void testToInt() throws IOException { + Map inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", 1); + inputToExpectedOutput.put("+1", 1); + inputToExpectedOutput.put("-1", -1); + inputToExpectedOutput.put("0", 0); + inputToExpectedOutput.put("11111.1234567", 11111); + inputToExpectedOutput.put(String.valueOf(Integer.MAX_VALUE), Integer.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Integer.MIN_VALUE), Integer.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + int value = rand.nextInt(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + IntWrapper intWrapper = new IntWrapper(); + for (Map.Entry entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toInt(intWrapper)); + assertEquals((int) entry.getValue(), intWrapper.value); + } + + List negativeInputs = + 
Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toInt(intWrapper)); + } + } + + @Test + public void testToLong() throws IOException { + Map inputToExpectedOutput = new HashMap<>(); + inputToExpectedOutput.put("1", 1L); + inputToExpectedOutput.put("+1", 1L); + inputToExpectedOutput.put("-1", -1L); + inputToExpectedOutput.put("0", 0L); + inputToExpectedOutput.put("1076753423.12345678901234567890", 1076753423L); + inputToExpectedOutput.put(String.valueOf(Long.MAX_VALUE), Long.MAX_VALUE); + inputToExpectedOutput.put(String.valueOf(Long.MIN_VALUE), Long.MIN_VALUE); + + Random rand = new Random(); + for (int i = 0; i < 10; i++) { + long value = rand.nextLong(); + inputToExpectedOutput.put(String.valueOf(value), value); + } + + LongWrapper wrapper = new LongWrapper(); + for (Map.Entry entry : inputToExpectedOutput.entrySet()) { + assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toLong(wrapper)); + assertEquals((long) entry.getValue(), wrapper.value); + } + + List negativeInputs = Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", + "1234567890123456789012345678901234"); + + for (String negativeInput : negativeInputs) { + assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper)); + } + } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala index 78aa5c40010cc..c98b87148e404 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskDescription.scala @@ -19,6 +19,7 @@ package org.apache.spark.scheduler import java.io.{DataInputStream, DataOutputStream} import java.nio.ByteBuffer +import java.nio.charset.StandardCharsets import java.util.Properties import scala.collection.JavaConverters._ @@ -86,7 +87,10 @@ private[spark] object TaskDescription { dataOut.writeInt(taskDescription.properties.size()) taskDescription.properties.asScala.foreach { case (key, value) => dataOut.writeUTF(key) - dataOut.writeUTF(value) + // SPARK-19796 -- writeUTF doesn't work for long strings, which can happen for property values + val bytes = value.getBytes(StandardCharsets.UTF_8) + dataOut.writeInt(bytes.length) + dataOut.write(bytes) } // Write the task. The task is already serialized, so write it directly to the byte buffer. @@ -124,7 +128,11 @@ private[spark] object TaskDescription { val properties = new Properties() val numProperties = dataIn.readInt() for (i <- 0 until numProperties) { - properties.setProperty(dataIn.readUTF(), dataIn.readUTF()) + val key = dataIn.readUTF() + val valueLength = dataIn.readInt() + val valueBytes = new Array[Byte](valueLength) + dataIn.readFully(valueBytes) + properties.setProperty(key, new String(valueBytes, StandardCharsets.UTF_8)) } // Create a sub-buffer for the serialized task into its own buffer (to be deserialized later). 
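The TaskDescription change above avoids `DataOutputStream.writeUTF`, which throws `UTFDataFormatException` once a string's encoded form exceeds 64KB, by writing each property value as a length-prefixed UTF-8 byte array. A minimal, self-contained sketch of the same round-trip pattern (the stream and variable names here are illustrative, not part of the patch):

{% highlight scala %}
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.nio.charset.StandardCharsets

val buffer = new ByteArrayOutputStream()
val out = new DataOutputStream(buffer)

// Write a length prefix followed by the raw UTF-8 bytes, so there is no 64KB limit.
val value = "x" * 100000
val bytes = value.getBytes(StandardCharsets.UTF_8)
out.writeInt(bytes.length)
out.write(bytes)
out.flush()

// Read the length back, then read exactly that many bytes.
val in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
val valueBytes = new Array[Byte](in.readInt())
in.readFully(valueBytes)
assert(new String(valueBytes, StandardCharsets.UTF_8) == value)
{% endhighlight %}

Similarly, the `UTF8String` changes earlier in this patch replace the exception-throwing parsers with variants that return `false` on malformed input and hand the parsed value back through a mutable wrapper. A hedged usage sketch of the new API from Scala (the input strings are arbitrary examples):

{% highlight scala %}
import org.apache.spark.unsafe.types.UTF8String

val wrapper = new UTF8String.LongWrapper()
if (UTF8String.fromString("9223372036854775807").toLong(wrapper)) {
  // Parsing succeeded; the result is read back through the wrapper field.
  println(wrapper.value)
}
if (!UTF8String.fromString("not a number").toLong(wrapper)) {
  // No NumberFormatException is thrown; the caller decides how to handle bad input.
  println("invalid long")
}
{% endhighlight %}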
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala index 9f1fe0515732e..97487ce1d2ca8 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskDescriptionSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler +import java.io.{ByteArrayOutputStream, DataOutputStream, UTFDataFormatException} import java.nio.ByteBuffer import java.util.Properties @@ -36,6 +37,21 @@ class TaskDescriptionSuite extends SparkFunSuite { val originalProperties = new Properties() originalProperties.put("property1", "18") originalProperties.put("property2", "test value") + // SPARK-19796 -- large property values (like a large job description for a long sql query) + // can cause problems for DataOutputStream, make sure we handle correctly + val sb = new StringBuilder() + (0 to 10000).foreach(_ => sb.append("1234567890")) + val largeString = sb.toString() + originalProperties.put("property3", largeString) + // make sure we've got a good test case + intercept[UTFDataFormatException] { + val out = new DataOutputStream(new ByteArrayOutputStream()) + try { + out.writeUTF(largeString) + } finally { + out.close() + } + } // Create a dummy byte buffer for the task. val taskBuffer = ByteBuffer.wrap(Array[Byte](1, 2, 3, 4)) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index ccede34b8cb4d..75dc04038debc 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -489,12 +489,12 @@ class BlockManagerProactiveReplicationSuite extends BlockManagerReplicationBehav Thread.sleep(200) } - // giving enough time for replication complete and locks released - Thread.sleep(500) - - val newLocations = master.getLocations(blockId).toSet + val newLocations = eventually(timeout(5 seconds), interval(10 millis)) { + val _newLocations = master.getLocations(blockId).toSet + assert(_newLocations.size === replicationFactor) + _newLocations + } logInfo(s"New locations : $newLocations") - assert(newLocations.size === replicationFactor) // there should only be one common block manager between initial and new locations assert(newLocations.intersect(blockLocations.toSet).size === 1) diff --git a/docs/ml-features.md b/docs/ml-features.md index 57605bafbf4c3..dad1c6db18f8b 100644 --- a/docs/ml-features.md +++ b/docs/ml-features.md @@ -503,6 +503,7 @@ for more details on the API. `StringIndexer` encodes a string column of labels to a column of label indices. The indices are in `[0, numLabels)`, ordered by label frequencies, so the most frequent label gets index `0`. +The unseen labels will be put at index numLabels if user chooses to keep them. If the input column is numeric, we cast it to string and index the string values. When downstream pipeline components such as `Estimator` or `Transformer` make use of this string-indexed label, you must set the input @@ -542,12 +543,13 @@ column, we should get the following: "a" gets index `0` because it is the most frequent, followed by "c" with index `1` and "b" with index `2`. 
-Additionally, there are two strategies regarding how `StringIndexer` will handle +Additionally, there are three strategies regarding how `StringIndexer` will handle unseen labels when you have fit a `StringIndexer` on one dataset and then use it to transform another: - throw an exception (which is the default) - skip the row containing the unseen label entirely +- put unseen labels in a special additional bucket, at index numLabels **Examples** @@ -561,6 +563,7 @@ Let's go back to our previous example but this time reuse our previously defined 1 | b 2 | c 3 | d + 4 | e ~~~~ If you've not set how `StringIndexer` handles unseen labels or set it to @@ -576,7 +579,22 @@ will be generated: 2 | c | 1.0 ~~~~ -Notice that the row containing "d" does not appear. +Notice that the rows containing "d" or "e" do not appear. + +If you call `setHandleInvalid("keep")`, the following dataset +will be generated: + +~~~~ + id | category | categoryIndex ----|----------|--------------- + 0 | a | 0.0 + 1 | b | 2.0 + 2 | c | 1.0 + 3 | d | 3.0 + 4 | e | 3.0 +~~~~ + +Notice that the rows containing "d" or "e" are mapped to index "3.0".
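To make the new `keep` behavior described above concrete, here is a short Scala sketch; `training` and `test` are placeholder DataFrames containing the `category` column from the example tables:

{% highlight scala %}
import org.apache.spark.ml.feature.StringIndexer

val indexer = new StringIndexer()
  .setInputCol("category")
  .setOutputCol("categoryIndex")
  .setHandleInvalid("keep")   // "error" (the default) and "skip" are the other two strategies

// Fit on the first dataset (labels a, b, c), then transform the second one,
// which also contains the unseen labels d and e.
val model = indexer.fit(training)
model.transform(test).show()  // d and e are both mapped to the extra index 3.0
{% endhighlight %}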
diff --git a/docs/quick-start.md b/docs/quick-start.md index aa4319a23325c..b88ae5f6bb313 100644 --- a/docs/quick-start.md +++ b/docs/quick-start.md @@ -10,12 +10,13 @@ description: Quick start tutorial for Spark SPARK_VERSION_SHORT This tutorial provides a quick introduction to using Spark. We will first introduce the API through Spark's interactive shell (in Python or Scala), then show how to write applications in Java, Scala, and Python. -See the [programming guide](programming-guide.html) for a more complete reference. To follow along with this guide, first download a packaged release of Spark from the [Spark website](http://spark.apache.org/downloads.html). Since we won't be using HDFS, you can download a package for any version of Hadoop. +Note that, before Spark 2.0, the main programming interface of Spark was the Resilient Distributed Dataset (RDD). After Spark 2.0, RDDs are replaced by Dataset, which is strongly-typed like an RDD, but with richer optimizations under the hood. The RDD interface is still supported, and you can get a more complete reference at the [RDD programming guide](rdd-programming-guide.html). However, we highly recommend you to switch to use Dataset, which has better performance than RDD. See the [SQL programming guide](sql-programming-guide.html) to get more information about Dataset. + # Interactive Analysis with the Spark Shell ## Basics @@ -29,28 +30,28 @@ or Python. Start it by running the following in the Spark directory: ./bin/spark-shell -Spark's primary abstraction is a distributed collection of items called a Resilient Distributed Dataset (RDD). RDDs can be created from Hadoop InputFormats (such as HDFS files) or by transforming other RDDs. Let's make a new RDD from the text of the README file in the Spark source directory: +Spark's primary abstraction is a distributed collection of items called a Dataset. Datasets can be created from Hadoop InputFormats (such as HDFS files) or by transforming other Datasets. Let's make a new Dataset from the text of the README file in the Spark source directory: {% highlight scala %} -scala> val textFile = sc.textFile("README.md") -textFile: org.apache.spark.rdd.RDD[String] = README.md MapPartitionsRDD[1] at textFile at :25 +scala> val textFile = spark.read.textFile("README.md") +textFile: org.apache.spark.sql.Dataset[String] = [value: string] {% endhighlight %} -RDDs have _[actions](programming-guide.html#actions)_, which return values, and _[transformations](programming-guide.html#transformations)_, which return pointers to new RDDs. Let's start with a few actions: +You can get values from Dataset directly, by calling some actions, or transform the Dataset to get a new one. For more details, please read the _[API doc](api/scala/index.html#org.apache.spark.sql.Dataset)_. {% highlight scala %} -scala> textFile.count() // Number of items in this RDD +scala> textFile.count() // Number of items in this Dataset res0: Long = 126 // May be different from yours as README.md will change over time, similar to other outputs -scala> textFile.first() // First item in this RDD +scala> textFile.first() // First item in this Dataset res1: String = # Apache Spark {% endhighlight %} -Now let's use a transformation. We will use the [`filter`](programming-guide.html#transformations) transformation to return a new RDD with a subset of the items in the file. +Now let's transform this Dataset to a new one. We call `filter` to return a new Dataset with a subset of the items in the file. 
{% highlight scala %} scala> val linesWithSpark = textFile.filter(line => line.contains("Spark")) -linesWithSpark: org.apache.spark.rdd.RDD[String] = MapPartitionsRDD[2] at filter at :27 +linesWithSpark: org.apache.spark.sql.Dataset[String] = [value: string] {% endhighlight %} We can chain together transformations and actions: @@ -65,32 +66,32 @@ res3: Long = 15 ./bin/pyspark -Spark's primary abstraction is a distributed collection of items called a Resilient Distributed Dataset (RDD). RDDs can be created from Hadoop InputFormats (such as HDFS files) or by transforming other RDDs. Let's make a new RDD from the text of the README file in the Spark source directory: +Spark's primary abstraction is a distributed collection of items called a Dataset. Datasets can be created from Hadoop InputFormats (such as HDFS files) or by transforming other Datasets. Due to Python's dynamic nature, we don't need the Dataset to be strongly-typed in Python. As a result, all Datasets in Python are Dataset[Row], and we call it `DataFrame` to be consistent with the data frame concept in Pandas and R. Let's make a new DataFrame from the text of the README file in the Spark source directory: {% highlight python %} ->>> textFile = sc.textFile("README.md") +>>> textFile = spark.read.text("README.md") {% endhighlight %} -RDDs have _[actions](programming-guide.html#actions)_, which return values, and _[transformations](programming-guide.html#transformations)_, which return pointers to new RDDs. Let's start with a few actions: +You can get values from DataFrame directly, by calling some actions, or transform the DataFrame to get a new one. For more details, please read the _[API doc](api/python/index.html#pyspark.sql.DataFrame)_. {% highlight python %} ->>> textFile.count() # Number of items in this RDD +>>> textFile.count() # Number of rows in this DataFrame 126 ->>> textFile.first() # First item in this RDD -u'# Apache Spark' +>>> textFile.first() # First row in this DataFrame +Row(value=u'# Apache Spark') {% endhighlight %} -Now let's use a transformation. We will use the [`filter`](programming-guide.html#transformations) transformation to return a new RDD with a subset of the items in the file. +Now let's transform this DataFrame to a new one. We call `filter` to return a new DataFrame with a subset of the lines in the file. {% highlight python %} ->>> linesWithSpark = textFile.filter(lambda line: "Spark" in line) +>>> linesWithSpark = textFile.filter(textFile.value.contains("Spark")) {% endhighlight %} We can chain together transformations and actions: {% highlight python %} ->>> textFile.filter(lambda line: "Spark" in line).count() # How many lines contain "Spark"? +>>> textFile.filter(textFile.value.contains("Spark")).count() # How many lines contain "Spark"? 15 {% endhighlight %} @@ -98,8 +99,8 @@ We can chain together transformations and actions:
-## More on RDD Operations -RDD actions and transformations can be used for more complex computations. Let's say we want to find the line with the most words: +## More on Dataset Operations +Dataset actions and transformations can be used for more complex computations. Let's say we want to find the line with the most words:
@@ -109,7 +110,7 @@ scala> textFile.map(line => line.split(" ").size).reduce((a, b) => if (a > b) a res4: Long = 15 {% endhighlight %} -This first maps a line to an integer value, creating a new RDD. `reduce` is called on that RDD to find the largest line count. The arguments to `map` and `reduce` are Scala function literals (closures), and can use any language feature or Scala/Java library. For example, we can easily call functions declared elsewhere. We'll use `Math.max()` function to make this code easier to understand: +This first maps a line to an integer value, creating a new Dataset. `reduce` is called on that Dataset to find the largest word count. The arguments to `map` and `reduce` are Scala function literals (closures), and can use any language feature or Scala/Java library. For example, we can easily call functions declared elsewhere. We'll use `Math.max()` function to make this code easier to understand: {% highlight scala %} scala> import java.lang.Math @@ -122,11 +123,11 @@ res5: Int = 15 One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can implement MapReduce flows easily: {% highlight scala %} -scala> val wordCounts = textFile.flatMap(line => line.split(" ")).map(word => (word, 1)).reduceByKey((a, b) => a + b) -wordCounts: org.apache.spark.rdd.RDD[(String, Int)] = ShuffledRDD[8] at reduceByKey at :28 +scala> val wordCounts = textFile.flatMap(line => line.split(" ")).groupByKey(identity).count() +wordCounts: org.apache.spark.sql.Dataset[(String, Long)] = [value: string, count(1): bigint] {% endhighlight %} -Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (String, Int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: +Here, we call `flatMap` to transform a Dataset of lines to a Dataset of words, and then combine `groupByKey` and `count` to compute the per-word counts in the file as a Dataset of (String, Long) pairs. To collect the word counts in our shell, we can call `collect`: {% highlight scala %} scala> wordCounts.collect() @@ -137,37 +138,24 @@ res6: Array[(String, Int)] = Array((means,1), (under,2), (this,3), (Because,1),
{% highlight python %} ->>> textFile.map(lambda line: len(line.split())).reduce(lambda a, b: a if (a > b) else b) -15 +>>> from pyspark.sql.functions import * +>>> textFile.select(size(split(textFile.value, "\s+")).name("numWords")).agg(max(col("numWords"))).collect() +[Row(max(numWords)=15)] {% endhighlight %} -This first maps a line to an integer value, creating a new RDD. `reduce` is called on that RDD to find the largest line count. The arguments to `map` and `reduce` are Python [anonymous functions (lambdas)](https://docs.python.org/2/reference/expressions.html#lambda), -but we can also pass any top-level Python function we want. -For example, we'll define a `max` function to make this code easier to understand: - -{% highlight python %} ->>> def max(a, b): -... if a > b: -... return a -... else: -... return b -... - ->>> textFile.map(lambda line: len(line.split())).reduce(max) -15 -{% endhighlight %} +This first maps a line to an integer value and aliases it as "numWords", creating a new DataFrame. `agg` is called on that DataFrame to find the largest word count. The arguments to `select` and `agg` are both _[Column](api/python/index.html#pyspark.sql.Column)_, so we can use `df.colName` to get a column from a DataFrame. We can also import pyspark.sql.functions, which provides a lot of convenient functions to build a new Column from an old one. One common data flow pattern is MapReduce, as popularized by Hadoop. Spark can implement MapReduce flows easily: {% highlight python %} ->>> wordCounts = textFile.select(explode(split(textFile.value, "\s+")).alias("word")).groupBy("word").count() {% endhighlight %} -Here, we combined the [`flatMap`](programming-guide.html#transformations), [`map`](programming-guide.html#transformations), and [`reduceByKey`](programming-guide.html#transformations) transformations to compute the per-word counts in the file as an RDD of (string, int) pairs. To collect the word counts in our shell, we can use the [`collect`](programming-guide.html#actions) action: +Here, we use the `explode` function in `select` to transform a Dataset of lines to a Dataset of words, and then combine `groupBy` and `count` to compute the per-word counts in the file as a DataFrame of 2 columns: "word" and "count". To collect the word counts in our shell, we can call `collect`: {% highlight python %} >>> wordCounts.collect() -[(u'and', 9), (u'A', 1), (u'webpage', 1), (u'README', 1), (u'Note', 1), (u'"local"', 1), (u'variable', 1), ...] +[Row(word=u'online', count=1), Row(word=u'graphs', count=1), ...] {% endhighlight %}
@@ -181,7 +169,7 @@ Spark also supports pulling data sets into a cluster-wide in-memory cache. This {% highlight scala %} scala> linesWithSpark.cache() -res7: linesWithSpark.type = MapPartitionsRDD[2] at filter at :27 +res7: linesWithSpark.type = [value: string] scala> linesWithSpark.count() res8: Long = 15 @@ -193,7 +181,7 @@ res9: Long = 15 It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is that these same functions can be used on very large data sets, even when they are striped across tens or hundreds of nodes. You can also do this interactively by connecting `bin/spark-shell` to -a cluster, as described in the [programming guide](programming-guide.html#initializing-spark). +a cluster, as described in the [RDD programming guide](rdd-programming-guide.html#using-the-shell).
@@ -211,7 +199,7 @@ a cluster, as described in the [programming guide](programming-guide.html#initia It may seem silly to use Spark to explore and cache a 100-line text file. The interesting part is that these same functions can be used on very large data sets, even when they are striped across tens or hundreds of nodes. You can also do this interactively by connecting `bin/pyspark` to -a cluster, as described in the [programming guide](programming-guide.html#initializing-spark). +a cluster, as described in the [RDD programming guide](rdd-programming-guide.html#using-the-shell).
@@ -228,20 +216,17 @@ named `SimpleApp.scala`: {% highlight scala %} /* SimpleApp.scala */ -import org.apache.spark.SparkContext -import org.apache.spark.SparkContext._ -import org.apache.spark.SparkConf +import org.apache.spark.sql.SparkSession object SimpleApp { def main(args: Array[String]) { val logFile = "YOUR_SPARK_HOME/README.md" // Should be some file on your system - val conf = new SparkConf().setAppName("Simple Application") - val sc = new SparkContext(conf) - val logData = sc.textFile(logFile, 2).cache() + val spark = SparkSession.builder.appName("Simple Application").getOrCreate() + val logData = spark.read.textFile(logFile).cache() val numAs = logData.filter(line => line.contains("a")).count() val numBs = logData.filter(line => line.contains("b")).count() println(s"Lines with a: $numAs, Lines with b: $numBs") - sc.stop() + spark.stop() } } {% endhighlight %} @@ -251,16 +236,13 @@ Subclasses of `scala.App` may not work correctly. This program just counts the number of lines containing 'a' and the number containing 'b' in the Spark README. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is -installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkContext, -we initialize a SparkContext as part of the program. +installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkSession, +we initialize a SparkSession as part of the program. -We pass the SparkContext constructor a -[SparkConf](api/scala/index.html#org.apache.spark.SparkConf) -object which contains information about our -application. +We call `SparkSession.builder` to construct a [[SparkSession]], then set the application name, and finally call `getOrCreate` to get the [[SparkSession]] instance. -Our application depends on the Spark API, so we'll also include an sbt configuration file, -`build.sbt`, which explains that Spark is a dependency. This file also adds a repository that +Our application depends on the Spark API, so we'll also include an sbt configuration file, +`build.sbt`, which explains that Spark is a dependency. 
This file also adds a repository that Spark depends on: {% highlight scala %} @@ -270,7 +252,7 @@ version := "1.0" scalaVersion := "{{site.SCALA_VERSION}}" -libraryDependencies += "org.apache.spark" %% "spark-core" % "{{site.SPARK_VERSION}}" +libraryDependencies += "org.apache.spark" %% "spark-sql" % "{{site.SPARK_VERSION}}" {% endhighlight %} For sbt to work correctly, we'll need to layout `SimpleApp.scala` and `build.sbt` @@ -309,34 +291,28 @@ We'll create a very simple Spark application, `SimpleApp.java`: {% highlight java %} /* SimpleApp.java */ -import org.apache.spark.api.java.*; -import org.apache.spark.SparkConf; -import org.apache.spark.api.java.function.Function; +import org.apache.spark.sql.SparkSession; public class SimpleApp { public static void main(String[] args) { String logFile = "YOUR_SPARK_HOME/README.md"; // Should be some file on your system - SparkConf conf = new SparkConf().setAppName("Simple Application"); - JavaSparkContext sc = new JavaSparkContext(conf); - JavaRDD logData = sc.textFile(logFile).cache(); + SparkSession spark = SparkSession.builder().appName("Simple Application").getOrCreate(); + Dataset logData = spark.read.textFile(logFile).cache(); long numAs = logData.filter(s -> s.contains("a")).count(); long numBs = logData.filter(s -> s.contains("b")).count(); System.out.println("Lines with a: " + numAs + ", lines with b: " + numBs); - - sc.stop(); + + spark.stop(); } } {% endhighlight %} -This program just counts the number of lines containing 'a' and the number containing 'b' in a text -file. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is installed. -As with the Scala example, we initialize a SparkContext, though we use the special -`JavaSparkContext` class to get a Java-friendly one. We also create RDDs (represented by -`JavaRDD`) and run transformations on them. Finally, we pass functions to Spark by creating classes -that extend `spark.api.java.function.Function`. The -[Spark programming guide](programming-guide.html) describes these differences in more detail. +This program just counts the number of lines containing 'a' and the number containing 'b' in the +Spark README. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is +installed. Unlike the earlier examples with the Spark shell, which initializes its own SparkSession, +we initialize a SparkSession as part of the program. To build the program, we also write a Maven `pom.xml` file that lists Spark as a dependency. Note that Spark artifacts are tagged with a Scala version. @@ -352,7 +328,7 @@ Note that Spark artifacts are tagged with a Scala version. 
org.apache.spark - spark-core_{{site.SCALA_BINARY_VERSION}} + spark-sql_{{site.SCALA_BINARY_VERSION}} {{site.SPARK_VERSION}} @@ -395,27 +371,25 @@ As an example, we'll create a simple Spark application, `SimpleApp.py`: {% highlight python %} """SimpleApp.py""" -from pyspark import SparkContext +from pyspark.sql import SparkSession logFile = "YOUR_SPARK_HOME/README.md" # Should be some file on your system -sc = SparkContext("local", "Simple App") -logData = sc.textFile(logFile).cache() +spark = SparkSession.builder().appName(appName).master(master).getOrCreate() +logData = spark.read.text(logFile).cache() -numAs = logData.filter(lambda s: 'a' in s).count() -numBs = logData.filter(lambda s: 'b' in s).count() +numAs = logData.filter(logData.value.contains('a')).count() +numBs = logData.filter(logData.value.contains('b')).count() print("Lines with a: %i, lines with b: %i" % (numAs, numBs)) -sc.stop() +spark.stop() {% endhighlight %} This program just counts the number of lines containing 'a' and the number containing 'b' in a text file. Note that you'll need to replace YOUR_SPARK_HOME with the location where Spark is installed. -As with the Scala and Java examples, we use a SparkContext to create RDDs. -We can pass Python functions to Spark, which are automatically serialized along with any variables -that they reference. +As with the Scala and Java examples, we use a SparkSession to create Datasets. For applications that use custom classes or third-party libraries, we can also add code dependencies to `spark-submit` through its `--py-files` argument by packaging them into a .zip file (see `spark-submit --help` for details). @@ -438,8 +412,7 @@ Lines with a: 46, Lines with b: 23 # Where to Go from Here Congratulations on running your first Spark application! -* For an in-depth overview of the API, start with the [Spark programming guide](programming-guide.html), - or see "Programming Guides" menu for other components. +* For an in-depth overview of the API, start with the [RDD programming guide](rdd-programming-guide.html) and the [SQL programming guide](sql-programming-guide.html), or see "Programming Guides" menu for other components. * For running applications on a cluster, head to the [deployment overview](cluster-overview.html). * Finally, Spark includes several samples in the `examples` directory ([Scala]({{site.SPARK_GITHUB_URL}}/tree/master/examples/src/main/scala/org/apache/spark/examples), diff --git a/docs/programming-guide.md b/docs/rdd-programming-guide.md similarity index 99% rename from docs/programming-guide.md rename to docs/rdd-programming-guide.md index 6740dbe0014b4..cad9ff4e646e5 100644 --- a/docs/programming-guide.md +++ b/docs/rdd-programming-guide.md @@ -24,7 +24,7 @@ along with if you launch Spark's interactive shell -- either `bin/spark-shell` f
-Spark {{site.SPARK_VERSION}} is built and distributed to work with Scala {{site.SCALA_BINARY_VERSION}} +Spark {{site.SPARK_VERSION}} is built and distributed to work with Scala {{site.SCALA_BINARY_VERSION}} by default. (Spark can be built to work with other versions of Scala, too.) To write applications in Scala, you will need to use a compatible Scala version (e.g. {{site.SCALA_BINARY_VERSION}}.X). @@ -76,10 +76,10 @@ In addition, if you wish to access an HDFS cluster, you need to add a dependency Finally, you need to import some Spark classes into your program. Add the following lines: -{% highlight scala %} -import org.apache.spark.api.java.JavaSparkContext -import org.apache.spark.api.java.JavaRDD -import org.apache.spark.SparkConf +{% highlight java %} +import org.apache.spark.api.java.JavaSparkContext; +import org.apache.spark.api.java.JavaRDD; +import org.apache.spark.SparkConf; {% endhighlight %}
@@ -244,13 +244,13 @@ use IPython, set the `PYSPARK_DRIVER_PYTHON` variable to `ipython` when running $ PYSPARK_DRIVER_PYTHON=ipython ./bin/pyspark {% endhighlight %} -To use the Jupyter notebook (previously known as the IPython notebook), +To use the Jupyter notebook (previously known as the IPython notebook), {% highlight bash %} $ PYSPARK_DRIVER_PYTHON=jupyter ./bin/pyspark {% endhighlight %} -You can customize the `ipython` or `jupyter` commands by setting `PYSPARK_DRIVER_PYTHON_OPTS`. +You can customize the `ipython` or `jupyter` commands by setting `PYSPARK_DRIVER_PYTHON_OPTS`. After the Jupyter Notebook server is launched, you can create a new "Python 2" notebook from the "Files" tab. Inside the notebook, you can input the command `%pylab inline` as part of @@ -811,7 +811,7 @@ The variables within the closure sent to each executor are now copies and thus, In local mode, in some circumstances the `foreach` function will actually execute within the same JVM as the driver and will reference the same original **counter**, and may actually update it. -To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. +To ensure well-defined behavior in these sorts of scenarios one should use an [`Accumulator`](#accumulators). Accumulators in Spark are used specifically to provide a mechanism for safely updating a variable when execution is split up across worker nodes in a cluster. The Accumulators section of this guide discusses these in more detail. In general, closures - constructs like loops or locally defined methods, should not be used to mutate some global state. Spark does not define or guarantee the behavior of mutations to objects referenced from outside of closures. Some code that does this may work in local mode, but that's just by accident and such code will not behave as expected in distributed mode. Use an Accumulator instead if some global aggregation is needed. @@ -1230,8 +1230,8 @@ storage levels is: -**Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, -so it does not matter whether you choose a serialized level. The available storage levels in Python include `MEMORY_ONLY`, `MEMORY_ONLY_2`, +**Note:** *In Python, stored objects will always be serialized with the [Pickle](https://docs.python.org/2/library/pickle.html) library, +so it does not matter whether you choose a serialized level. The available storage levels in Python include `MEMORY_ONLY`, `MEMORY_ONLY_2`, `MEMORY_AND_DISK`, `MEMORY_AND_DISK_2`, `DISK_ONLY`, and `DISK_ONLY_2`.* Spark also automatically persists some intermediate data in shuffle operations (e.g. `reduceByKey`), even without users calling `persist`. This is done to avoid recomputing the entire input if a node fails during the shuffle. We still recommend users call `persist` on the resulting RDD if they plan to reuse it. @@ -1346,7 +1346,7 @@ As a user, you can create named or unnamed accumulators. As seen in the image be Accumulators in the Spark UI
-Tracking accumulators in the UI can be useful for understanding the progress of +Tracking accumulators in the UI can be useful for understanding the progress of running stages (NOTE: this is not yet supported in Python).
@@ -1355,7 +1355,7 @@ running stages (NOTE: this is not yet supported in Python). A numeric accumulator can be created by calling `SparkContext.longAccumulator()` or `SparkContext.doubleAccumulator()` to accumulate values of type Long or Double, respectively. Tasks running on a cluster can then add to it using -the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value, +the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value, using its `value` method. The code below shows an accumulator being used to add up the elements of an array: @@ -1409,7 +1409,7 @@ Note that, when programmers define their own type of AccumulatorV2, the resultin A numeric accumulator can be created by calling `SparkContext.longAccumulator()` or `SparkContext.doubleAccumulator()` to accumulate values of type Long or Double, respectively. Tasks running on a cluster can then add to it using -the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value, +the `add` method. However, they cannot read its value. Only the driver program can read the accumulator's value, using its `value` method. The code below shows an accumulator being used to add up the elements of an array: diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala new file mode 100644 index 0000000000000..08914d82fffdd --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSink.scala @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.execution.streaming.Sink + +private[kafka010] class KafkaSink( + sqlContext: SQLContext, + executorKafkaParams: ju.Map[String, Object], + topic: Option[String]) extends Sink with Logging { + @volatile private var latestBatchId = -1L + + override def toString(): String = "KafkaSink" + + override def addBatch(batchId: Long, data: DataFrame): Unit = { + if (batchId <= latestBatchId) { + logInfo(s"Skipping already committed batch $batchId") + } else { + KafkaWriter.write(sqlContext.sparkSession, + data.queryExecution, executorKafkaParams, topic) + latestBatchId = batchId + } + } +} diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala index 6a7456719875f..febe3c217122a 100644 --- a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaSourceProvider.scala @@ -23,12 +23,14 @@ import java.util.UUID import scala.collection.JavaConverters._ import org.apache.kafka.clients.consumer.ConsumerConfig -import org.apache.kafka.common.serialization.ByteArrayDeserializer +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.{ByteArrayDeserializer, ByteArraySerializer} import org.apache.spark.internal.Logging -import org.apache.spark.sql.SQLContext -import org.apache.spark.sql.execution.streaming.Source +import org.apache.spark.sql.{AnalysisException, DataFrame, SaveMode, SQLContext} +import org.apache.spark.sql.execution.streaming.{Sink, Source} import org.apache.spark.sql.sources._ +import org.apache.spark.sql.streaming.OutputMode import org.apache.spark.sql.types.StructType /** @@ -36,8 +38,12 @@ import org.apache.spark.sql.types.StructType * IllegalArgumentException when the Kafka Dataset is created, so that it can catch * missing options even before the query is started. */ -private[kafka010] class KafkaSourceProvider extends DataSourceRegister with StreamSourceProvider - with RelationProvider with Logging { +private[kafka010] class KafkaSourceProvider extends DataSourceRegister + with StreamSourceProvider + with StreamSinkProvider + with RelationProvider + with CreatableRelationProvider + with Logging { import KafkaSourceProvider._ override def shortName(): String = "kafka" @@ -152,6 +158,72 @@ private[kafka010] class KafkaSourceProvider extends DataSourceRegister with Stre endingRelationOffsets) } + override def createSink( + sqlContext: SQLContext, + parameters: Map[String, String], + partitionColumns: Seq[String], + outputMode: OutputMode): Sink = { + val defaultTopic = parameters.get(TOPIC_OPTION_KEY).map(_.trim) + val specifiedKafkaParams = kafkaParamsForProducer(parameters) + new KafkaSink(sqlContext, + new ju.HashMap[String, Object](specifiedKafkaParams.asJava), defaultTopic) + } + + override def createRelation( + outerSQLContext: SQLContext, + mode: SaveMode, + parameters: Map[String, String], + data: DataFrame): BaseRelation = { + mode match { + case SaveMode.Overwrite | SaveMode.Ignore => + throw new AnalysisException(s"Save mode $mode not allowed for Kafka. 
" + + s"Allowed save modes are ${SaveMode.Append} and " + + s"${SaveMode.ErrorIfExists} (default).") + case _ => // good + } + val topic = parameters.get(TOPIC_OPTION_KEY).map(_.trim) + val specifiedKafkaParams = kafkaParamsForProducer(parameters) + KafkaWriter.write(outerSQLContext.sparkSession, data.queryExecution, + new ju.HashMap[String, Object](specifiedKafkaParams.asJava), topic) + + /* This method is suppose to return a relation that reads the data that was written. + * We cannot support this for Kafka. Therefore, in order to make things consistent, + * we return an empty base relation. + */ + new BaseRelation { + override def sqlContext: SQLContext = unsupportedException + override def schema: StructType = unsupportedException + override def needConversion: Boolean = unsupportedException + override def sizeInBytes: Long = unsupportedException + override def unhandledFilters(filters: Array[Filter]): Array[Filter] = unsupportedException + private def unsupportedException = + throw new UnsupportedOperationException("BaseRelation from Kafka write " + + "operation is not usable.") + } + } + + private def kafkaParamsForProducer(parameters: Map[String, String]): Map[String, String] = { + val caseInsensitiveParams = parameters.map { case (k, v) => (k.toLowerCase, v) } + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}")) { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG}' is not supported as keys " + + "are serialized with ByteArraySerializer.") + } + + if (caseInsensitiveParams.contains(s"kafka.${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}")) + { + throw new IllegalArgumentException( + s"Kafka option '${ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG}' is not supported as " + + "value are serialized with ByteArraySerializer.") + } + parameters + .keySet + .filter(_.toLowerCase.startsWith("kafka.")) + .map { k => k.drop(6).toString -> parameters(k) } + .toMap + (ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName, + ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG -> classOf[ByteArraySerializer].getName) + } + private def kafkaParamsForDriver(specifiedKafkaParams: Map[String, String]) = ConfigUpdater("source", specifiedKafkaParams) .set(ConsumerConfig.KEY_DESERIALIZER_CLASS_CONFIG, deserClassName) @@ -381,6 +453,7 @@ private[kafka010] object KafkaSourceProvider { private val STARTING_OFFSETS_OPTION_KEY = "startingoffsets" private val ENDING_OFFSETS_OPTION_KEY = "endingoffsets" private val FAIL_ON_DATA_LOSS_OPTION_KEY = "failondataloss" + val TOPIC_OPTION_KEY = "topic" private val deserClassName = classOf[ByteArrayDeserializer].getName } diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala new file mode 100644 index 0000000000000..6e160cbe2db52 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriteTask.scala @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.kafka.clients.producer.{KafkaProducer, _} +import org.apache.kafka.common.serialization.ByteArraySerializer + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, Cast, Literal, UnsafeProjection} +import org.apache.spark.sql.types.{BinaryType, StringType} + +/** + * A simple trait for writing out data in a single Spark task, without any concerns about how + * to commit or abort tasks. Exceptions thrown by the implementation of this class will + * automatically trigger task aborts. + */ +private[kafka010] class KafkaWriteTask( + producerConfiguration: ju.Map[String, Object], + inputSchema: Seq[Attribute], + topic: Option[String]) { + // used to synchronize with Kafka callbacks + @volatile private var failedWrite: Exception = null + private val projection = createProjection + private var producer: KafkaProducer[Array[Byte], Array[Byte]] = _ + + /** + * Writes key value data out to topics. + */ + def execute(iterator: Iterator[InternalRow]): Unit = { + producer = new KafkaProducer[Array[Byte], Array[Byte]](producerConfiguration) + while (iterator.hasNext && failedWrite == null) { + val currentRow = iterator.next() + val projectedRow = projection(currentRow) + val topic = projectedRow.getUTF8String(0) + val key = projectedRow.getBinary(1) + val value = projectedRow.getBinary(2) + if (topic == null) { + throw new NullPointerException(s"null topic present in the data. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a default topic.") + } + val record = new ProducerRecord[Array[Byte], Array[Byte]](topic.toString, key, value) + val callback = new Callback() { + override def onCompletion(recordMetadata: RecordMetadata, e: Exception): Unit = { + if (failedWrite == null && e != null) { + failedWrite = e + } + } + } + producer.send(record, callback) + } + } + + def close(): Unit = { + if (producer != null) { + checkForErrors + producer.close() + checkForErrors + producer = null + } + } + + private def createProjection: UnsafeProjection = { + val topicExpression = topic.map(Literal(_)).orElse { + inputSchema.find(_.name == KafkaWriter.TOPIC_ATTRIBUTE_NAME) + }.getOrElse { + throw new IllegalStateException(s"topic option required when no " + + s"'${KafkaWriter.TOPIC_ATTRIBUTE_NAME}' attribute is present") + } + topicExpression.dataType match { + case StringType => // good + case t => + throw new IllegalStateException(s"${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + + s"attribute unsupported type $t. 
${KafkaWriter.TOPIC_ATTRIBUTE_NAME} " + + s"must be a ${StringType}") + } + val keyExpression = inputSchema.find(_.name == KafkaWriter.KEY_ATTRIBUTE_NAME) + .getOrElse(Literal(null, BinaryType)) + keyExpression.dataType match { + case StringType | BinaryType => // good + case t => + throw new IllegalStateException(s"${KafkaWriter.KEY_ATTRIBUTE_NAME} " + + s"attribute unsupported type $t") + } + val valueExpression = inputSchema + .find(_.name == KafkaWriter.VALUE_ATTRIBUTE_NAME).getOrElse( + throw new IllegalStateException(s"Required attribute " + + s"'${KafkaWriter.VALUE_ATTRIBUTE_NAME}' not found") + ) + valueExpression.dataType match { + case StringType | BinaryType => // good + case t => + throw new IllegalStateException(s"${KafkaWriter.VALUE_ATTRIBUTE_NAME} " + + s"attribute unsupported type $t") + } + UnsafeProjection.create( + Seq(topicExpression, Cast(keyExpression, BinaryType), + Cast(valueExpression, BinaryType)), inputSchema) + } + + private def checkForErrors: Unit = { + if (failedWrite != null) { + throw failedWrite + } + } +} + diff --git a/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala new file mode 100644 index 0000000000000..a637d52c933a3 --- /dev/null +++ b/external/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala @@ -0,0 +1,97 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.kafka010 + +import java.{util => ju} + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.{AnalysisException, SparkSession} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.execution.{QueryExecution, SQLExecution} +import org.apache.spark.sql.types.{BinaryType, StringType} +import org.apache.spark.util.Utils + +/** + * The [[KafkaWriter]] class is used to write data from a batch query + * or structured streaming query, given by a [[QueryExecution]], to Kafka. + * The data is assumed to have a value column, and an optional topic and key + * columns. If the topic column is missing, then the topic must come from + * the 'topic' configuration option. If the key column is missing, then a + * null valued key field will be added to the + * [[org.apache.kafka.clients.producer.ProducerRecord]]. 
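The @volatile failedWrite field in KafkaWriteTask above is how asynchronous producer failures reach the Spark task: the send callback records the first exception, and checkForErrors rethrows it before and after close(). A self-contained sketch of that pattern (not the class itself; the class name below is made up):

import org.apache.kafka.clients.producer.{Callback, KafkaProducer, ProducerRecord, RecordMetadata}

// Remember the first asynchronous send failure and surface it on demand.
class SendWithErrorCapture(producer: KafkaProducer[Array[Byte], Array[Byte]]) {
  @volatile private var failedWrite: Exception = null

  def send(record: ProducerRecord[Array[Byte], Array[Byte]]): Unit = {
    producer.send(record, new Callback {
      override def onCompletion(metadata: RecordMetadata, e: Exception): Unit = {
        if (failedWrite == null && e != null) failedWrite = e  // keep only the first error
      }
    })
  }

  def checkForErrors(): Unit = if (failedWrite != null) throw failedWrite
}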
+ */ +private[kafka010] object KafkaWriter extends Logging { + val TOPIC_ATTRIBUTE_NAME: String = "topic" + val KEY_ATTRIBUTE_NAME: String = "key" + val VALUE_ATTRIBUTE_NAME: String = "value" + + override def toString: String = "KafkaWriter" + + def validateQuery( + queryExecution: QueryExecution, + kafkaParameters: ju.Map[String, Object], + topic: Option[String] = None): Unit = { + val schema = queryExecution.logical.output + schema.find(_.name == TOPIC_ATTRIBUTE_NAME).getOrElse( + if (topic == None) { + throw new AnalysisException(s"topic option required when no " + + s"'$TOPIC_ATTRIBUTE_NAME' attribute is present. Use the " + + s"${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.") + } else { + Literal(topic.get, StringType) + } + ).dataType match { + case StringType => // good + case _ => + throw new AnalysisException(s"Topic type must be a String") + } + schema.find(_.name == KEY_ATTRIBUTE_NAME).getOrElse( + Literal(null, StringType) + ).dataType match { + case StringType | BinaryType => // good + case _ => + throw new AnalysisException(s"$KEY_ATTRIBUTE_NAME attribute type " + + s"must be a String or BinaryType") + } + schema.find(_.name == VALUE_ATTRIBUTE_NAME).getOrElse( + throw new AnalysisException(s"Required attribute '$VALUE_ATTRIBUTE_NAME' not found") + ).dataType match { + case StringType | BinaryType => // good + case _ => + throw new AnalysisException(s"$VALUE_ATTRIBUTE_NAME attribute type " + + s"must be a String or BinaryType") + } + } + + def write( + sparkSession: SparkSession, + queryExecution: QueryExecution, + kafkaParameters: ju.Map[String, Object], + topic: Option[String] = None): Unit = { + val schema = queryExecution.logical.output + validateQuery(queryExecution, kafkaParameters, topic) + SQLExecution.withNewExecutionId(sparkSession, queryExecution) { + queryExecution.toRdd.foreachPartition { iter => + val writeTask = new KafkaWriteTask(kafkaParameters, schema, topic) + Utils.tryWithSafeFinally(block = writeTask.execute(iter))( + finallyBlock = writeTask.close()) + } + } + } +} diff --git a/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala new file mode 100644 index 0000000000000..490535623cb36 --- /dev/null +++ b/external/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaSinkSuite.scala @@ -0,0 +1,412 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
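The validation in KafkaWriter above boils down to: a "value" column (string or binary) is required, "key" is optional, and the topic must come from either a "topic" column or the "topic" option. A hedged sketch of a batch write that satisfies those rules; the DataFrame, its column names and the broker address are placeholders:

// Assumes a SparkSession `spark` and `import spark.implicits._`.
val df = Seq((1, "a"), (2, "b")).toDF("id", "payload")
df.selectExpr(
    "'events' AS topic",               // alternatively omit this and set .option("topic", "events")
    "CAST(id AS STRING) AS key",       // optional; a null key is used when absent
    "CAST(payload AS STRING) AS value")
  .write
  .format("kafka")
  .option("kafka.bootstrap.servers", "host:9092")
  .save()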
+ */ + +package org.apache.spark.sql.kafka010 + +import java.util.concurrent.atomic.AtomicInteger + +import org.apache.kafka.clients.producer.ProducerConfig +import org.apache.kafka.common.serialization.ByteArraySerializer +import org.scalatest.time.SpanSugar._ + +import org.apache.spark.SparkException +import org.apache.spark.sql._ +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, SpecificInternalRow, UnsafeProjection} +import org.apache.spark.sql.execution.streaming.MemoryStream +import org.apache.spark.sql.streaming._ +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.{BinaryType, DataType} + +class KafkaSinkSuite extends StreamTest with SharedSQLContext { + import testImplicits._ + + protected var testUtils: KafkaTestUtils = _ + + override val streamingTimeout = 30.seconds + + override def beforeAll(): Unit = { + super.beforeAll() + testUtils = new KafkaTestUtils( + withBrokerProps = Map("auto.create.topics.enable" -> "false")) + testUtils.setup() + } + + override def afterAll(): Unit = { + if (testUtils != null) { + testUtils.teardown() + testUtils = null + super.afterAll() + } + } + + test("batch - write to kafka") { + val topic = newTopic() + testUtils.createTopic(topic) + val df = Seq("1", "2", "3", "4", "5").map(v => (topic, v)).toDF("topic", "value") + df.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("topic", topic) + .save() + checkAnswer( + createKafkaReader(topic).selectExpr("CAST(value as STRING) value"), + Row("1") :: Row("2") :: Row("3") :: Row("4") :: Row("5") :: Nil) + } + + test("batch - null topic field value, and no topic option") { + val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value") + val ex = intercept[SparkException] { + df.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .save() + } + assert(ex.getMessage.toLowerCase.contains( + "null topic present in the data")) + } + + test("batch - unsupported save modes") { + val topic = newTopic() + testUtils.createTopic(topic) + val df = Seq[(String, String)](null.asInstanceOf[String] -> "1").toDF("topic", "value") + + // Test bad save mode Ignore + var ex = intercept[AnalysisException] { + df.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .mode(SaveMode.Ignore) + .save() + } + assert(ex.getMessage.toLowerCase.contains( + s"save mode ignore not allowed for kafka")) + + // Test bad save mode Overwrite + ex = intercept[AnalysisException] { + df.write + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .mode(SaveMode.Overwrite) + .save() + } + assert(ex.getMessage.toLowerCase.contains( + s"save mode overwrite not allowed for kafka")) + } + + test("streaming - write to kafka with topic field") { + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF(), + withTopic = None, + withOutputMode = Some(OutputMode.Append))( + withSelectExpr = s"'$topic' as topic", "value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + .map(_._2) + + try { + input.addData("1", "2", "3", "4", "5") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5) + input.addData("6", "7", "8", "9", "10") + 
failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10) + } finally { + writer.stop() + } + } + + test("streaming - write aggregation w/o topic field, with topic option") { + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF().groupBy("value").count(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Update()))( + withSelectExpr = "CAST(value as STRING) key", "CAST(count as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key as STRING) key", "CAST(value as STRING) value") + .selectExpr("CAST(key as INT) key", "CAST(value as INT) value") + .as[(Int, Int)] + + try { + input.addData("1", "2", "2", "3", "3", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3)) + input.addData("1", "2", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3), (1, 2), (2, 3), (3, 4)) + } finally { + writer.stop() + } + } + + test("streaming - aggregation with topic field and topic option") { + /* The purpose of this test is to ensure that the topic option + * overrides the topic field. We begin by writing some data that + * includes a topic field and value (e.g., 'foo') along with a topic + * option. Then when we read from the topic specified in the option + * we should see the data i.e., the data was written to the topic + * option, and not to the topic in the data e.g., foo + */ + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + val writer = createKafkaWriter( + input.toDF().groupBy("value").count(), + withTopic = Some(topic), + withOutputMode = Some(OutputMode.Update()))( + withSelectExpr = "'foo' as topic", + "CAST(value as STRING) key", "CAST(count as STRING) value") + + val reader = createKafkaReader(topic) + .selectExpr("CAST(key AS STRING)", "CAST(value AS STRING)") + .selectExpr("CAST(key AS INT)", "CAST(value AS INT)") + .as[(Int, Int)] + + try { + input.addData("1", "2", "2", "3", "3", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3)) + input.addData("1", "2", "3") + failAfter(streamingTimeout) { + writer.processAllAvailable() + } + checkDatasetUnorderly(reader, (1, 1), (2, 2), (3, 3), (1, 2), (2, 3), (3, 4)) + } finally { + writer.stop() + } + } + + + test("streaming - write data with bad schema") { + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + /* No topic field or topic option */ + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = "value as key", "value" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage + .toLowerCase + .contains("topic option required when no 'topic' attribute is present")) + + try { + /* No value field */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "value as key" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase.contains("required attribute 'value' not found")) + } + + 
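The streaming tests above drive the sink through this suite's createKafkaWriter helper; for reference, a minimal sketch of the equivalent user-facing query, assuming the same MemoryStream input and test implicits as this suite (broker address, topic and checkpoint path are placeholders):

val input = MemoryStream[String]
val query = input.toDF()
  .selectExpr("CAST(value AS STRING) AS value")
  .writeStream
  .format("kafka")
  .option("kafka.bootstrap.servers", "host:9092")
  .option("topic", "events")
  .option("checkpointLocation", "/tmp/kafka-sink-checkpoint")
  .start()
// input.addData("1", "2", "3"); query.processAllAvailable() pushes the rows to Kafka.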
test("streaming - write data with valid schema but wrong types") { + val input = MemoryStream[String] + val topic = newTopic() + testUtils.createTopic(topic) + + var writer: StreamingQuery = null + var ex: Exception = null + try { + /* topic field wrong type */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"CAST('1' as INT) as topic", "value" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase.contains("topic type must be a string")) + + try { + /* value field wrong type */ + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as value" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase.contains( + "value attribute type must be a string or binarytype")) + + try { + ex = intercept[StreamingQueryException] { + /* key field wrong type */ + writer = createKafkaWriter(input.toDF())( + withSelectExpr = s"'$topic' as topic", "CAST(value as INT) as key", "value" + ) + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase.contains( + "key attribute type must be a string or binarytype")) + } + + test("streaming - write to non-existing topic") { + val input = MemoryStream[String] + val topic = newTopic() + + var writer: StreamingQuery = null + var ex: Exception = null + try { + ex = intercept[StreamingQueryException] { + writer = createKafkaWriter(input.toDF(), withTopic = Some(topic))() + input.addData("1", "2", "3", "4", "5") + writer.processAllAvailable() + } + } finally { + writer.stop() + } + assert(ex.getMessage.toLowerCase.contains("job aborted")) + } + + test("streaming - exception on config serializer") { + val input = MemoryStream[String] + var writer: StreamingQuery = null + var ex: Exception = null + ex = intercept[IllegalArgumentException] { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.key.serializer" -> "foo"))() + } + assert(ex.getMessage.toLowerCase.contains( + "kafka option 'key.serializer' is not supported")) + + ex = intercept[IllegalArgumentException] { + writer = createKafkaWriter( + input.toDF(), + withOptions = Map("kafka.value.serializer" -> "foo"))() + } + assert(ex.getMessage.toLowerCase.contains( + "kafka option 'value.serializer' is not supported")) + } + + test("generic - write big data with small producer buffer") { + /* This test ensures that we understand the semantics of Kafka when + * is comes to blocking on a call to send when the send buffer is full. + * This test will configure the smallest possible producer buffer and + * indicate that we should block when it is full. Thus, no exception should + * be thrown in the case of a full buffer. 
+ */ + val topic = newTopic() + testUtils.createTopic(topic, 1) + val options = new java.util.HashMap[String, Object] + options.put("bootstrap.servers", testUtils.brokerAddress) + options.put("buffer.memory", "16384") // min buffer size + options.put("block.on.buffer.full", "true") + options.put(ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + options.put(ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG, classOf[ByteArraySerializer].getName) + val inputSchema = Seq(AttributeReference("value", BinaryType)()) + val data = new Array[Byte](15000) // large value + val writeTask = new KafkaWriteTask(options, inputSchema, Some(topic)) + try { + val fieldTypes: Array[DataType] = Array(BinaryType) + val converter = UnsafeProjection.create(fieldTypes) + val row = new SpecificInternalRow(fieldTypes) + row.update(0, data) + val iter = Seq.fill(1000)(converter.apply(row)).iterator + writeTask.execute(iter) + } finally { + writeTask.close() + } + } + + private val topicId = new AtomicInteger(0) + + private def newTopic(): String = s"topic-${topicId.getAndIncrement()}" + + private def createKafkaReader(topic: String): DataFrame = { + spark.read + .format("kafka") + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .option("startingOffsets", "earliest") + .option("endingOffsets", "latest") + .option("subscribe", topic) + .load() + } + + private def createKafkaWriter( + input: DataFrame, + withTopic: Option[String] = None, + withOutputMode: Option[OutputMode] = None, + withOptions: Map[String, String] = Map[String, String]()) + (withSelectExpr: String*): StreamingQuery = { + var stream: DataStreamWriter[Row] = null + withTempDir { checkpointDir => + var df = input.toDF() + if (withSelectExpr.length > 0) { + df = df.selectExpr(withSelectExpr: _*) + } + stream = df.writeStream + .format("kafka") + .option("checkpointLocation", checkpointDir.getCanonicalPath) + .option("kafka.bootstrap.servers", testUtils.brokerAddress) + .queryName("kafkaStream") + withTopic.foreach(stream.option("topic", _)) + withOutputMode.foreach(stream.outputMode(_)) + withOptions.foreach(opt => stream.option(opt._1, opt._2)) + } + stream.start() + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index a503411b63612..810b02febbe77 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -17,6 +17,8 @@ package org.apache.spark.ml.feature +import scala.language.existentials + import org.apache.hadoop.fs.Path import org.apache.spark.SparkException @@ -24,7 +26,7 @@ import org.apache.spark.annotation.Since import org.apache.spark.ml.{Estimator, Model, Transformer} import org.apache.spark.ml.attribute.{Attribute, NominalAttribute} import org.apache.spark.ml.param._ -import org.apache.spark.ml.param.shared._ +import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset} import org.apache.spark.sql.functions._ @@ -34,8 +36,27 @@ import org.apache.spark.util.collection.OpenHashMap /** * Base trait for [[StringIndexer]] and [[StringIndexerModel]]. */ -private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol - with HasHandleInvalid { +private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { + + /** + * Param for how to handle unseen labels. 
Options are 'skip' (filter out rows with + * unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional + * bucket, at index numLabels). + * Default: "error" + * @group param + */ + @Since("1.6.0") + val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " + + "unseen labels. Options are 'skip' (filter out rows with unseen labels), " + + "'error' (throw an error), or 'keep' (put unseen labels in a special additional bucket, " + + "at index numLabels).", + ParamValidators.inArray(StringIndexer.supportedHandleInvalids)) + + setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL) + + /** @group getParam */ + @Since("1.6.0") + def getHandleInvalid: String = $(handleInvalid) /** Validates and transforms the input schema. */ protected def validateAndTransformSchema(schema: StructType): StructType = { @@ -73,7 +94,6 @@ class StringIndexer @Since("1.4.0") ( /** @group setParam */ @Since("1.6.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, "error") /** @group setParam */ @Since("1.4.0") @@ -105,6 +125,11 @@ class StringIndexer @Since("1.4.0") ( @Since("1.6.0") object StringIndexer extends DefaultParamsReadable[StringIndexer] { + private[feature] val SKIP_UNSEEN_LABEL: String = "skip" + private[feature] val ERROR_UNSEEN_LABEL: String = "error" + private[feature] val KEEP_UNSEEN_LABEL: String = "keep" + private[feature] val supportedHandleInvalids: Array[String] = + Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL) @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) @@ -144,7 +169,6 @@ class StringIndexerModel ( /** @group setParam */ @Since("1.6.0") def setHandleInvalid(value: String): this.type = set(handleInvalid, value) - setDefault(handleInvalid, "error") /** @group setParam */ @Since("1.4.0") @@ -163,25 +187,34 @@ class StringIndexerModel ( } transformSchema(dataset.schema, logging = true) - val indexer = udf { label: String => - if (labelToIndex.contains(label)) { - labelToIndex(label) - } else { - throw new SparkException(s"Unseen label: $label.") - } + val filteredLabels = getHandleInvalid match { + case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown" + case _ => labels } + val metadata = NominalAttribute.defaultAttr - .withName($(outputCol)).withValues(labels).toMetadata() + .withName($(outputCol)).withValues(filteredLabels).toMetadata() + // If we are skipping invalid records, filter them out. - val filteredDataset = getHandleInvalid match { - case "skip" => + val (filteredDataset, keepInvalid) = getHandleInvalid match { + case StringIndexer.SKIP_UNSEEN_LABEL => val filterer = udf { label: String => labelToIndex.contains(label) } - dataset.where(filterer(dataset($(inputCol)))) - case _ => dataset + (dataset.where(filterer(dataset($(inputCol)))), false) + case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL) } + + val indexer = udf { label: String => + if (labelToIndex.contains(label)) { + labelToIndex(label) + } else if (keepInvalid) { + labels.length + } else { + throw new SparkException(s"Unseen label: $label.
To handle unseen labels, " + + s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.") + } + } + filteredDataset.select(col("*"), indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata)) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 42e8a66a62b61..4ca062c0b5adf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -227,25 +227,50 @@ class Word2VecModel private[ml] ( /** * Find "num" number of words closest in similarity to the given word, not - * including the word itself. Returns a dataframe with the words and the - * cosine similarities between the synonyms and the given word. + * including the word itself. + * @return a dataframe with columns "word" and "similarity" of the word and the cosine + * similarities between the synonyms and the given word vector. */ @Since("1.5.0") def findSynonyms(word: String, num: Int): DataFrame = { val spark = SparkSession.builder().getOrCreate() - spark.createDataFrame(wordVectors.findSynonyms(word, num)).toDF("word", "similarity") + spark.createDataFrame(findSynonymsArray(word, num)).toDF("word", "similarity") } /** - * Find "num" number of words whose vector representation most similar to the supplied vector. + * Find "num" number of words whose vector representation is most similar to the supplied vector. * If the supplied vector is the vector representation of a word in the model's vocabulary, - * that word will be in the results. Returns a dataframe with the words and the cosine + * that word will be in the results. + * @return a dataframe with columns "word" and "similarity" of the word and the cosine * similarities between the synonyms and the given word vector. */ @Since("2.0.0") def findSynonyms(vec: Vector, num: Int): DataFrame = { val spark = SparkSession.builder().getOrCreate() - spark.createDataFrame(wordVectors.findSynonyms(vec, num)).toDF("word", "similarity") + spark.createDataFrame(findSynonymsArray(vec, num)).toDF("word", "similarity") + } + + /** + * Find "num" number of words whose vector representation is most similar to the supplied vector. + * If the supplied vector is the vector representation of a word in the model's vocabulary, + * that word will be in the results. + * @return an array of the words and the cosine similarities between the synonyms given + * word vector. + */ + @Since("2.2.0") + def findSynonymsArray(vec: Vector, num: Int): Array[(String, Double)] = { + wordVectors.findSynonyms(vec, num) + } + + /** + * Find "num" number of words closest in similarity to the given word, not + * including the word itself. + * @return an array of the words and the cosine similarities between the synonyms given + * word vector. 
+ */ + @Since("2.2.0") + def findSynonymsArray(word: String, num: Int): Array[(String, Double)] = { + wordVectors.findSynonyms(word, num) } /** @group setParam */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index 110764dc074f7..3be8b533ee3f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -66,7 +66,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam /** * Param for the power in the variance function of the Tweedie distribution which provides * the relationship between the variance and mean of the distribution. - * Only applicable for the Tweedie family. + * Only applicable to the Tweedie family. * (see * Tweedie Distribution (Wikipedia)) * Supported values: 0 and [1, Inf). @@ -79,7 +79,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam final val variancePower: DoubleParam = new DoubleParam(this, "variancePower", "The power in the variance function of the Tweedie distribution which characterizes " + "the relationship between the variance and mean of the distribution. " + - "Only applicable for the Tweedie family. Supported values: 0 and [1, Inf).", + "Only applicable to the Tweedie family. Supported values: 0 and [1, Inf).", (x: Double) => x >= 1.0 || x == 0.0) /** @group getParam */ @@ -106,7 +106,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam def getLink: String = $(link) /** - * Param for the index in the power link function. Only applicable for the Tweedie family. + * Param for the index in the power link function. Only applicable to the Tweedie family. * Note that link power 0, 1, -1 or 0.5 corresponds to the Log, Identity, Inverse or Sqrt * link, respectively. * When not set, this value defaults to 1 - [[variancePower]], which matches the R "statmod" @@ -116,7 +116,7 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam */ @Since("2.2.0") final val linkPower: DoubleParam = new DoubleParam(this, "linkPower", - "The index in the power link function. Only applicable for the Tweedie family.") + "The index in the power link function. 
Only applicable to the Tweedie family.") /** @group getParam */ @Since("2.2.0") diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala index 2364d43aaa0e2..531c8b07910fc 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala @@ -30,6 +30,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.SparkContext import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD +import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.Logging import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.util.{Loader, Saveable} @@ -314,6 +315,20 @@ class Word2Vec extends Serializable with Logging { val expTable = sc.broadcast(createExpTable()) val bcVocab = sc.broadcast(vocab) val bcVocabHash = sc.broadcast(vocabHash) + try { + doFit(dataset, sc, expTable, bcVocab, bcVocabHash) + } finally { + expTable.destroy(blocking = false) + bcVocab.destroy(blocking = false) + bcVocabHash.destroy(blocking = false) + } + } + + private def doFit[S <: Iterable[String]]( + dataset: RDD[S], sc: SparkContext, + expTable: Broadcast[Array[Float]], + bcVocab: Broadcast[Array[VocabWord]], + bcVocabHash: Broadcast[mutable.HashMap[String, Int]]) = { // each partition is a collection of sentences, // will be translated into arrays of Index integer val sentences: RDD[Array[Int]] = dataset.mapPartitions { sentenceIter => @@ -435,9 +450,6 @@ class Word2Vec extends Serializable with Logging { bcSyn1Global.destroy(false) } newSentences.unpersist() - expTable.destroy(false) - bcVocab.destroy(false) - bcVocabHash.destroy(false) val wordArray = vocab.map(_.word) new Word2VecModel(wordArray.zipWithIndex.toMap, syn0Global) diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index c711e7fa9dc67..10de50306a5ce 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -372,16 +372,18 @@ class DecisionTreeClassifierSuite // Categorical splits with tree depth 2 val categoricalData: DataFrame = TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 2) - testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(dt, categoricalData, allParamSettings, + allParamSettings, checkModelData) // Continuous splits with tree depth 2 val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) - testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings, + allParamSettings, checkModelData) // Continuous splits with tree depth 0 testEstimatorAndModelReadWrite(dt, continuousData, allParamSettings ++ Map("maxDepth" -> 0), - checkModelData) + allParamSettings ++ Map("maxDepth" -> 0), checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 0598943c3d4be..0cddb37281b39 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -374,7 +374,8 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) - testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, + allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index a165d8a9345cf..4c63a2a88c6c6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -24,12 +24,13 @@ import breeze.linalg.{DenseVector => BDV} import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LinearSVCSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.sql.functions.udf class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -41,6 +42,9 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau @transient var smallValidationDataset: Dataset[_] = _ @transient var binaryDataset: Dataset[_] = _ + @transient var smallSparseBinaryDataset: Dataset[_] = _ + @transient var smallSparseValidationDataset: Dataset[_] = _ + override def beforeAll(): Unit = { super.beforeAll() @@ -51,6 +55,13 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau smallBinaryDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 42).toDF() smallValidationDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 17).toDF() binaryDataset = generateSVMInput(1.0, Array[Double](1.0, 2.0, 3.0, 4.0), 10000, 42).toDF() + + // Dataset for testing SparseVector + val toSparse: Vector => SparseVector = _.asInstanceOf[DenseVector].toSparse + val sparse = udf(toSparse) + smallSparseBinaryDataset = smallBinaryDataset.withColumn("features", sparse('features)) + smallSparseValidationDataset = smallValidationDataset.withColumn("features", sparse('features)) + } /** @@ -68,6 +79,8 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val model = svm.fit(smallBinaryDataset) assert(model.transform(smallValidationDataset) .where("prediction=label").count() > nPoints * 0.8) + val sparseModel = svm.fit(smallSparseBinaryDataset) + checkModels(model, sparseModel) } test("Linear SVC binary classification with regularization") { @@ -75,6 +88,8 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val model = svm.setRegParam(0.1).fit(smallBinaryDataset) assert(model.transform(smallValidationDataset) .where("prediction=label").count() > nPoints * 0.8) + val sparseModel = svm.fit(smallSparseBinaryDataset) + checkModels(model, sparseModel) } test("params") { @@ -217,7 +232,7 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau } val svm = new 
LinearSVC() testEstimatorAndModelReadWrite(svm, smallBinaryDataset, LinearSVCSuite.allParamSettings, - checkModelData) + LinearSVCSuite.allParamSettings, checkModelData) } } @@ -235,7 +250,7 @@ object LinearSVCSuite { "aggregationDepth" -> 3 ) - // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) + // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) def generateSVMInput( intercept: Double, weights: Array[Double], @@ -252,5 +267,10 @@ object LinearSVCSuite { y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } + def checkModels(model1: LinearSVCModel, model2: LinearSVCModel): Unit = { + assert(model1.intercept == model2.intercept) + assert(model1.coefficients.equals(model2.coefficients)) + } + } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index d89a958eed45a..affaa573749e8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -2089,7 +2089,7 @@ class LogisticRegressionSuite } val lr = new LogisticRegression() testEstimatorAndModelReadWrite(lr, smallBinaryDataset, LogisticRegressionSuite.allParamSettings, - checkModelData) + LogisticRegressionSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and weights, and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 37d7991fe8dd8..4d5d299d1408f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -280,7 +280,8 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa assert(model.theta === model2.theta) } val nb = new NaiveBayes() - testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, + NaiveBayesSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and weights, and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index 44e1585ee514b..c3003cec73b41 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -218,7 +218,8 @@ class RandomForestClassifierSuite val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 2) - testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, + allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala index 30513c1e276ae..200a892f6c694 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala +++ 
b/mllib/src/test/scala/org/apache/spark/ml/clustering/BisectingKMeansSuite.scala @@ -138,8 +138,8 @@ class BisectingKMeansSuite assert(model.clusterCenters === model2.clusterCenters) } val bisectingKMeans = new BisectingKMeans() - testEstimatorAndModelReadWrite( - bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(bisectingKMeans, dataset, BisectingKMeansSuite.allParamSettings, + BisectingKMeansSuite.allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala index c500c5b3e365a..61da897b666f4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/GaussianMixtureSuite.scala @@ -163,7 +163,7 @@ class GaussianMixtureSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.gaussians.map(_.cov) === model2.gaussians.map(_.cov)) } val gm = new GaussianMixture() - testEstimatorAndModelReadWrite(gm, dataset, + testEstimatorAndModelReadWrite(gm, dataset, GaussianMixtureSuite.allParamSettings, GaussianMixtureSuite.allParamSettings, checkModelData) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala index e10127f7d108f..ca05b9c389f65 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/KMeansSuite.scala @@ -150,7 +150,8 @@ class KMeansSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultR assert(model.clusterCenters === model2.clusterCenters) } val kmeans = new KMeans() - testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(kmeans, dataset, KMeansSuite.allParamSettings, + KMeansSuite.allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala index 9aa11fbdbe868..75aa0be61a3ed 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala @@ -250,7 +250,8 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead Vectors.dense(model2.getDocConcentration) absTol 1e-6) } val lda = new LDA() - testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(lda, dataset, LDASuite.allParamSettings, + LDASuite.allParamSettings, checkModelData) } test("read/write DistributedLDAModel") { @@ -271,6 +272,7 @@ class LDASuite extends SparkFunSuite with MLlibTestSparkContext with DefaultRead } val lda = new LDA() testEstimatorAndModelReadWrite(lda, dataset, + LDASuite.allParamSettings ++ Map("optimizer" -> "em"), LDASuite.allParamSettings ++ Map("optimizer" -> "em"), checkModelData) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala index ab937685a555c..91eac9e733312 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala @@ -63,7 +63,7 @@ class 
BucketedRandomProjectionLSHSuite } val mh = new BucketedRandomProjectionLSH() val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0) - testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData) + testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData) } test("hashFunction") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala index 482e5d54260d4..d6925da97d57e 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/ChiSqSelectorSuite.scala @@ -151,7 +151,8 @@ class ChiSqSelectorSuite extends SparkFunSuite with MLlibTestSparkContext assert(model.selectedFeatures === model2.selectedFeatures) } val nb = new ChiSqSelector - testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(nb, dataset, ChiSqSelectorSuite.allParamSettings, + ChiSqSelectorSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala index 3461cdf82460f..a2f009310fd7a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala @@ -54,7 +54,7 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa } val mh = new MinHashLSH() val settings = Map("inputCol" -> "keys", "outputCol" -> "values") - testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData) + testEstimatorAndModelReadWrite(mh, dataset, settings, settings, checkModelData) } test("hashFunction") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 2d0e63c9d669c..188dffb3dd55f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -64,7 +64,7 @@ class StringIndexerSuite test("StringIndexerUnseen") { val data = Seq((0, "a"), (1, "b"), (4, "b")) - val data2 = Seq((0, "a"), (1, "b"), (2, "c")) + val data2 = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d")) val df = data.toDF("id", "label") val df2 = data2.toDF("id", "label") val indexer = new StringIndexer() @@ -75,22 +75,32 @@ class StringIndexerSuite intercept[SparkException] { indexer.transform(df2).collect() } - val indexerSkipInvalid = new StringIndexer() - .setInputCol("label") - .setOutputCol("labelIndex") - .setHandleInvalid("skip") - .fit(df) + + indexer.setHandleInvalid("skip") // Verify that we skip the c record - val transformed = indexerSkipInvalid.transform(df2) - val attr = Attribute.fromStructField(transformed.schema("labelIndex")) + val transformedSkip = indexer.transform(df2) + val attrSkip = Attribute.fromStructField(transformedSkip.schema("labelIndex")) .asInstanceOf[NominalAttribute] - assert(attr.values.get === Array("b", "a")) - val output = transformed.select("id", "labelIndex").rdd.map { r => + assert(attrSkip.values.get === Array("b", "a")) + val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => (r.getInt(0), r.getDouble(1)) }.collect().toSet // a -> 1, b -> 0 - val 
expected = Set((0, 1.0), (1, 0.0)) - assert(output === expected) + val expectedSkip = Set((0, 1.0), (1, 0.0)) + assert(outputSkip === expectedSkip) + + indexer.setHandleInvalid("keep") + // Verify that we keep the unseen records + val transformedKeep = indexer.transform(df2) + val attrKeep = Attribute.fromStructField(transformedKeep.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0, c -> 2, d -> 3 + val expectedKeep = Set((0, 1.0), (1, 0.0), (2, 2.0), (3, 2.0)) + assert(outputKeep === expectedKeep) } test("StringIndexer with a numeric input column") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala index 613cc3d60b227..2043a16c15f1a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala @@ -133,14 +133,22 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul .setSeed(42L) .fit(docDF) - val expectedSimilarity = Array(0.2608488929093532, -0.8271274846926078) - val (synonyms, similarity) = model.findSynonyms("a", 2).rdd.map { + val expected = Map(("b", 0.2608488929093532), ("c", -0.8271274846926078)) + val findSynonymsResult = model.findSynonyms("a", 2).rdd.map { case Row(w: String, sim: Double) => (w, sim) - }.collect().unzip + }.collectAsMap() + + expected.foreach { + case (expectedSynonym, expectedSimilarity) => + assert(findSynonymsResult.contains(expectedSynonym)) + assert(expectedSimilarity ~== findSynonymsResult.get(expectedSynonym).get absTol 1E-5) + } - assert(synonyms === Array("b", "c")) - expectedSimilarity.zip(similarity).foreach { - case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5) + val findSynonymsArrayResult = model.findSynonymsArray("a", 2).toMap + findSynonymsResult.foreach { + case (expectedSynonym, expectedSimilarity) => + assert(findSynonymsArrayResult.contains(expectedSynonym)) + assert(expectedSimilarity ~== findSynonymsArrayResult.get(expectedSynonym).get absTol 1E-5) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala index 74c7461401905..076d55c180548 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/fpm/FPGrowthSuite.scala @@ -99,8 +99,8 @@ class FPGrowthSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul model2.freqItemsets.sort("items").collect()) } val fPGrowth = new FPGrowth() - testEstimatorAndModelReadWrite( - fPGrowth, dataset, FPGrowthSuite.allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(fPGrowth, dataset, FPGrowthSuite.allParamSettings, + FPGrowthSuite.allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala index e494ea89e63bd..a177ed13bf8ef 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/recommendation/ALSSuite.scala @@ -518,37 +518,26 @@ class ALSSuite } test("read/write") { - import ALSSuite._ - val (ratings, _) = genExplicitTestData(numUsers = 4, 
numItems = 4, rank = 1) - val als = new ALS() - allEstimatorParamSettings.foreach { case (p, v) => - als.set(als.getParam(p), v) - } val spark = this.spark import spark.implicits._ - val model = als.fit(ratings.toDF()) - - // Test Estimator save/load - val als2 = testDefaultReadWrite(als) - allEstimatorParamSettings.foreach { case (p, v) => - val param = als.getParam(p) - assert(als.get(param).get === als2.get(param).get) - } + import ALSSuite._ + val (ratings, _) = genExplicitTestData(numUsers = 4, numItems = 4, rank = 1) - // Test Model save/load - val model2 = testDefaultReadWrite(model) - allModelParamSettings.foreach { case (p, v) => - val param = model.getParam(p) - assert(model.get(param).get === model2.get(param).get) - } - assert(model.rank === model2.rank) def getFactors(df: DataFrame): Set[(Int, Array[Float])] = { df.select("id", "features").collect().map { case r => (r.getInt(0), r.getAs[Array[Float]](1)) }.toSet } - assert(getFactors(model.userFactors) === getFactors(model2.userFactors)) - assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors)) + + def checkModelData(model: ALSModel, model2: ALSModel): Unit = { + assert(model.rank === model2.rank) + assert(getFactors(model.userFactors) === getFactors(model2.userFactors)) + assert(getFactors(model.itemFactors) === getFactors(model2.itemFactors)) + } + + val als = new ALS() + testEstimatorAndModelReadWrite(als, ratings.toDF(), allEstimatorParamSettings, + allModelParamSettings, checkModelData) } test("input type validation") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala index 3cd4b0ac308ef..708185a0943df 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/AFTSurvivalRegressionSuite.scala @@ -419,7 +419,8 @@ class AFTSurvivalRegressionSuite } val aft = new AFTSurvivalRegression() testEstimatorAndModelReadWrite(aft, datasetMultivariate, - AFTSurvivalRegressionSuite.allParamSettings, checkModelData) + AFTSurvivalRegressionSuite.allParamSettings, AFTSurvivalRegressionSuite.allParamSettings, + checkModelData) } test("SPARK-15892: Incorrectly merged AFTAggregator with zero total count") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 15fa26e8b5272..0e91284d03d98 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -165,16 +165,17 @@ class DecisionTreeRegressorSuite val categoricalData: DataFrame = TreeTests.setMetadata(rdd, Map(0 -> 2, 1 -> 3), numClasses = 0) testEstimatorAndModelReadWrite(dt, categoricalData, - TreeTests.allParamSettings, checkModelData) + TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData) // Continuous splits with tree depth 2 val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) testEstimatorAndModelReadWrite(dt, continuousData, - TreeTests.allParamSettings, checkModelData) + TreeTests.allParamSettings, TreeTests.allParamSettings, checkModelData) // Continuous splits with tree depth 0 testEstimatorAndModelReadWrite(dt, continuousData, + TreeTests.allParamSettings ++ Map("maxDepth" -> 0), TreeTests.allParamSettings ++ 
Map("maxDepth" -> 0), checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala index dcf3f9a1ea9b2..03c2f97797bce 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala @@ -184,7 +184,8 @@ class GBTRegressorSuite extends SparkFunSuite with MLlibTestSparkContext val allParamSettings = TreeTests.allParamSettings ++ Map("lossType" -> "squared") val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) - testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings, + allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala index add28a72b6808..401911763fa3b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala @@ -1418,6 +1418,7 @@ class GeneralizedLinearRegressionSuite val glr = new GeneralizedLinearRegression() testEstimatorAndModelReadWrite(glr, datasetPoissonLog, + GeneralizedLinearRegressionSuite.allParamSettings, GeneralizedLinearRegressionSuite.allParamSettings, checkModelData) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala index 8cbb2acad243e..f41a3601b1fa8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/IsotonicRegressionSuite.scala @@ -178,7 +178,7 @@ class IsotonicRegressionSuite val ir = new IsotonicRegression() testEstimatorAndModelReadWrite(ir, dataset, IsotonicRegressionSuite.allParamSettings, - checkModelData) + IsotonicRegressionSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and weights, and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala index 584a1b272f6c8..6a51e75e12a36 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala @@ -985,7 +985,7 @@ class LinearRegressionSuite } val lr = new LinearRegression() testEstimatorAndModelReadWrite(lr, datasetWithWeight, LinearRegressionSuite.allParamSettings, - checkModelData) + LinearRegressionSuite.allParamSettings, checkModelData) } test("should support all NumericType labels and weights, and not support other types") { diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala index c08335f9f84af..3bf0445ebd3dd 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala @@ -124,7 +124,8 @@ class RandomForestRegressorSuite extends SparkFunSuite 
with MLlibTestSparkContex val continuousData: DataFrame = TreeTests.setMetadata(rdd, Map.empty[Int, Int], numClasses = 0) - testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, checkModelData) + testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings, + allParamSettings, checkModelData) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala index 553b8725b30a3..bfe8f12258bb8 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/util/DefaultReadWriteTest.scala @@ -85,11 +85,12 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => * - Check Params on Estimator and Model * - Compare model data * - * This requires that the [[Estimator]] and [[Model]] share the same set of [[Param]]s. + * This requires that [[Model]]'s [[Param]]s should be a subset of [[Estimator]]'s [[Param]]s. * * @param estimator Estimator to test * @param dataset Dataset to pass to [[Estimator.fit()]] - * @param testParams Set of [[Param]] values to set in estimator + * @param testEstimatorParams Set of [[Param]] values to set in estimator + * @param testModelParams Set of [[Param]] values to set in model * @param checkModelData Method which takes the original and loaded [[Model]] and compares their * data. This method does not need to check [[Param]] values. * @tparam E Type of [[Estimator]] @@ -99,24 +100,25 @@ trait DefaultReadWriteTest extends TempDirectory { self: Suite => E <: Estimator[M] with MLWritable, M <: Model[M] with MLWritable]( estimator: E, dataset: Dataset[_], - testParams: Map[String, Any], + testEstimatorParams: Map[String, Any], + testModelParams: Map[String, Any], checkModelData: (M, M) => Unit): Unit = { // Set some Params to make sure set Params are serialized. - testParams.foreach { case (p, v) => + testEstimatorParams.foreach { case (p, v) => estimator.set(estimator.getParam(p), v) } val model = estimator.fit(dataset) // Test Estimator save/load val estimator2 = testDefaultReadWrite(estimator) - testParams.foreach { case (p, v) => + testEstimatorParams.foreach { case (p, v) => val param = estimator.getParam(p) assert(estimator.get(param).get === estimator2.get(param).get) } // Test Model save/load val model2 = testDefaultReadWrite(model) - testParams.foreach { case (p, v) => + testModelParams.foreach { case (p, v) => val param = model.getParam(p) assert(model.get(param).get === model2.get(param).get) } diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 56b8c0b95e8a4..bd4528bd21264 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -914,6 +914,10 @@ object MimaExcludes { ) ++ Seq( // [SPARK-17163] Unify logistic regression interface. Private constructor has new signature. 
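With the DefaultReadWriteTest change above, call sites now pass one Param map for the estimator and one for the model (the model map must be a subset of the model's Params), as the ALSSuite rewrite earlier in this patch does with allEstimatorParamSettings and allModelParamSettings. A schematic call, with made-up settings maps:

// estimator, dataset and checkModelData are whatever the individual suite defines;
// the Param names below are hypothetical and must exist on the estimator/model used.
val estimatorSettings = Map("maxIter" -> 2, "regParam" -> 0.01)
val modelSettings = Map("predictionCol" -> "pred")
testEstimatorAndModelReadWrite(estimator, dataset, estimatorSettings, modelSettings, checkModelData)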
ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.this") + ) ++ Seq( + // [SPARK-17498] StringIndexer enhancement for handling unseen labels + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexer"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.feature.StringIndexerModel") ) ++ Seq( // [SPARK-17365][Core] Remove/Kill multiple executors together to reduce RPC call time ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.SparkContext") diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index b199bf282e4f2..3c3fcc8d9b8d8 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -1294,8 +1294,8 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha Fit a Generalized Linear Model specified by giving a symbolic description of the linear predictor (link function) and a description of the error distribution (family). It supports - "gaussian", "binomial", "poisson" and "gamma" as family. Valid link functions for each family - is listed below. The first link function of each family is the default one. + "gaussian", "binomial", "poisson", "gamma" and "tweedie" as family. Valid link functions for + each family is listed below. The first link function of each family is the default one. * "gaussian" -> "identity", "log", "inverse" @@ -1305,6 +1305,9 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha * "gamma" -> "inverse", "identity", "log" + * "tweedie" -> power link function specified through "linkPower". \ + The default link power in the tweedie family is 1 - variancePower. + .. seealso:: `GLM `_ >>> from pyspark.ml.linalg import Vectors @@ -1344,7 +1347,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha family = Param(Params._dummy(), "family", "The name of family which is a description of " + "the error distribution to be used in the model. Supported options: " + - "gaussian (default), binomial, poisson and gamma.", + "gaussian (default), binomial, poisson, gamma and tweedie.", typeConverter=TypeConverters.toString) link = Param(Params._dummy(), "link", "The name of link function which provides the " + "relationship between the linear predictor and the mean of the distribution " + @@ -1352,32 +1355,46 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha "and sqrt.", typeConverter=TypeConverters.toString) linkPredictionCol = Param(Params._dummy(), "linkPredictionCol", "link prediction (linear " + "predictor) column name", typeConverter=TypeConverters.toString) + variancePower = Param(Params._dummy(), "variancePower", "The power in the variance function " + + "of the Tweedie distribution which characterizes the relationship " + + "between the variance and mean of the distribution. Only applicable " + + "for the Tweedie family. Supported values: 0 and [1, Inf).", + typeConverter=TypeConverters.toFloat) + linkPower = Param(Params._dummy(), "linkPower", "The index in the power link function. 
" + + "Only applicable to the Tweedie family.", + typeConverter=TypeConverters.toFloat) @keyword_only def __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, - regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None): + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, + variancePower=0.0, linkPower=None): """ __init__(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ - regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None) + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ + variancePower=0.0, linkPower=None) """ super(GeneralizedLinearRegression, self).__init__() self._java_obj = self._new_java_obj( "org.apache.spark.ml.regression.GeneralizedLinearRegression", self.uid) - self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls") + self._setDefault(family="gaussian", maxIter=25, tol=1e-6, regParam=0.0, solver="irls", + variancePower=0.0) kwargs = self._input_kwargs + self.setParams(**kwargs) @keyword_only @since("2.0.0") def setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, - regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None): + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, + variancePower=0.0, linkPower=None): """ setParams(self, labelCol="label", featuresCol="features", predictionCol="prediction", \ family="gaussian", link=None, fitIntercept=True, maxIter=25, tol=1e-6, \ - regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None) + regParam=0.0, weightCol=None, solver="irls", linkPredictionCol=None, \ + variancePower=0.0, linkPower=None) Sets params for generalized linear regression. """ kwargs = self._input_kwargs @@ -1428,6 +1445,34 @@ def getLink(self): """ return self.getOrDefault(self.link) + @since("2.2.0") + def setVariancePower(self, value): + """ + Sets the value of :py:attr:`variancePower`. + """ + return self._set(variancePower=value) + + @since("2.2.0") + def getVariancePower(self): + """ + Gets the value of variancePower or its default value. + """ + return self.getOrDefault(self.variancePower) + + @since("2.2.0") + def setLinkPower(self, value): + """ + Sets the value of :py:attr:`linkPower`. + """ + return self._set(linkPower=value) + + @since("2.2.0") + def getLinkPower(self): + """ + Gets the value of linkPower or its default value. 
+ """ + return self.getOrDefault(self.linkPower) + class GeneralizedLinearRegressionModel(JavaModel, JavaPredictionModel, JavaMLWritable, JavaMLReadable): diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 352416055791e..f052f5bb770c6 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1223,6 +1223,26 @@ def test_apply_binary_term_freqs(self): ": expected " + str(expected[i]) + ", got " + str(features[i])) +class GeneralizedLinearRegressionTest(SparkSessionTestCase): + + def test_tweedie_distribution(self): + + df = self.spark.createDataFrame( + [(1.0, Vectors.dense(0.0, 0.0)), + (1.0, Vectors.dense(1.0, 2.0)), + (2.0, Vectors.dense(0.0, 0.0)), + (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"]) + + glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6) + model = glr.fit(df) + self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4)) + self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4)) + + model2 = glr.setLinkPower(-1.0).fit(df) + self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4)) + self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4)) + + class ALSTest(SparkSessionTestCase): def test_storage_levels(self): diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala index 55c91675ed3ba..121a02a9be0a1 100644 --- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala +++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -473,4 +473,15 @@ class ReplSuite extends SparkFunSuite { assertDoesNotContain("AssertionError", output) assertDoesNotContain("Exception", output) } + + test("newProductSeqEncoder with REPL defined class") { + val output = runInterpreterInPasteMode("local-cluster[1,4,4096]", + """ + |case class Click(id: Int) + |spark.implicits.newProductSeqEncoder[Click] + """.stripMargin) + + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala index 2760f31b12fa7..1bc6f71860c3f 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterScheduler.scala @@ -152,6 +152,7 @@ private[spark] class MesosClusterScheduler( // is registered with Mesos master. @volatile protected var ready = false private var masterInfo: Option[MasterInfo] = None + private var schedulerDriver: SchedulerDriver = _ def submitDriver(desc: MesosDriverDescription): CreateSubmissionResponse = { val c = new CreateSubmissionResponse @@ -168,9 +169,8 @@ private[spark] class MesosClusterScheduler( return c } c.submissionId = desc.submissionId - queuedDriversState.persist(desc.submissionId, desc) - queuedDrivers += desc c.success = true + addDriverToQueue(desc) } c } @@ -191,7 +191,7 @@ private[spark] class MesosClusterScheduler( // 4. Check if it has already completed. 
if (launchedDrivers.contains(submissionId)) { val task = launchedDrivers(submissionId) - mesosDriver.killTask(task.taskId) + schedulerDriver.killTask(task.taskId) k.success = true k.message = "Killing running driver" } else if (removeFromQueuedDrivers(submissionId)) { @@ -324,7 +324,7 @@ private[spark] class MesosClusterScheduler( ready = false metricsSystem.report() metricsSystem.stop() - mesosDriver.stop(true) + schedulerDriver.stop(true) } override def registered( @@ -340,6 +340,8 @@ private[spark] class MesosClusterScheduler( stateLock.synchronized { this.masterInfo = Some(masterInfo) + this.schedulerDriver = driver + if (!pendingRecover.isEmpty) { // Start task reconciliation if we need to recover. val statuses = pendingRecover.collect { @@ -506,11 +508,10 @@ private[spark] class MesosClusterScheduler( } private class ResourceOffer( - val offerId: OfferID, - val slaveId: SlaveID, - var resources: JList[Resource]) { + val offer: Offer, + var remainingResources: JList[Resource]) { override def toString(): String = { - s"Offer id: ${offerId}, resources: ${resources}" + s"Offer id: ${offer.getId}, resources: ${remainingResources}" } } @@ -518,16 +519,16 @@ private[spark] class MesosClusterScheduler( val taskId = TaskID.newBuilder().setValue(desc.submissionId).build() val (remainingResources, cpuResourcesToUse) = - partitionResources(offer.resources, "cpus", desc.cores) + partitionResources(offer.remainingResources, "cpus", desc.cores) val (finalResources, memResourcesToUse) = partitionResources(remainingResources.asJava, "mem", desc.mem) - offer.resources = finalResources.asJava + offer.remainingResources = finalResources.asJava val appName = desc.conf.get("spark.app.name") val taskInfo = TaskInfo.newBuilder() .setTaskId(taskId) .setName(s"Driver for ${appName}") - .setSlaveId(offer.slaveId) + .setSlaveId(offer.offer.getSlaveId) .setCommand(buildDriverCommand(desc)) .addAllResources(cpuResourcesToUse.asJava) .addAllResources(memResourcesToUse.asJava) @@ -549,23 +550,29 @@ private[spark] class MesosClusterScheduler( val driverCpu = submission.cores val driverMem = submission.mem logTrace(s"Finding offer to launch driver with cpu: $driverCpu, mem: $driverMem") - val offerOption = currentOffers.find { o => - getResource(o.resources, "cpus") >= driverCpu && - getResource(o.resources, "mem") >= driverMem + val offerOption = currentOffers.find { offer => + getResource(offer.remainingResources, "cpus") >= driverCpu && + getResource(offer.remainingResources, "mem") >= driverMem } if (offerOption.isEmpty) { logDebug(s"Unable to find offer to launch driver id: ${submission.submissionId}, " + s"cpu: $driverCpu, mem: $driverMem") } else { val offer = offerOption.get - val queuedTasks = tasks.getOrElseUpdate(offer.offerId, new ArrayBuffer[TaskInfo]) + val queuedTasks = tasks.getOrElseUpdate(offer.offer.getId, new ArrayBuffer[TaskInfo]) try { val task = createTaskInfo(submission, offer) queuedTasks += task - logTrace(s"Using offer ${offer.offerId.getValue} to launch driver " + + logTrace(s"Using offer ${offer.offer.getId.getValue} to launch driver " + submission.submissionId) - val newState = new MesosClusterSubmissionState(submission, task.getTaskId, offer.slaveId, - None, new Date(), None, getDriverFrameworkID(submission)) + val newState = new MesosClusterSubmissionState( + submission, + task.getTaskId, + offer.offer.getSlaveId, + None, + new Date(), + None, + getDriverFrameworkID(submission)) launchedDrivers(submission.submissionId) = newState launchedDriversState.persist(submission.submissionId, 
newState) afterLaunchCallback(submission.submissionId) @@ -588,7 +595,7 @@ private[spark] class MesosClusterScheduler( val currentTime = new Date() val currentOffers = offers.asScala.map { - o => new ResourceOffer(o.getId, o.getSlaveId, o.getResourcesList) + offer => new ResourceOffer(offer, offer.getResourcesList) }.toList stateLock.synchronized { @@ -615,8 +622,8 @@ private[spark] class MesosClusterScheduler( driver.launchTasks(Collections.singleton(offerId), taskInfos.asJava) } - for (o <- currentOffers if !tasks.contains(o.offerId)) { - driver.declineOffer(o.offerId) + for (offer <- currentOffers if !tasks.contains(offer.offer.getId)) { + declineOffer(driver, offer.offer, None, Some(getRejectOfferDuration(conf))) } } @@ -662,6 +669,12 @@ private[spark] class MesosClusterScheduler( override def statusUpdate(driver: SchedulerDriver, status: TaskStatus): Unit = { val taskId = status.getTaskId.getValue + + logInfo(s"Received status update: taskId=${taskId}" + + s" state=${status.getState}" + + s" message=${status.getMessage}" + + s" reason=${status.getReason}"); + stateLock.synchronized { if (launchedDrivers.contains(taskId)) { if (status.getReason == Reason.REASON_RECONCILIATION && @@ -682,8 +695,7 @@ private[spark] class MesosClusterScheduler( val newDriverDescription = state.driverDescription.copy( retryState = Some(new MesosClusterRetryState(status, retries, nextRetry, waitTimeSec))) - pendingRetryDrivers += newDriverDescription - pendingRetryDriversState.persist(taskId, newDriverDescription) + addDriverToPending(newDriverDescription, taskId); } else if (TaskState.isFinished(mesosToTaskState(status.getState))) { removeFromLaunchedDrivers(taskId) state.finishDate = Some(new Date()) @@ -746,4 +758,21 @@ private[spark] class MesosClusterScheduler( def getQueuedDriversSize: Int = queuedDrivers.size def getLaunchedDriversSize: Int = launchedDrivers.size def getPendingRetryDriversSize: Int = pendingRetryDrivers.size + + private def addDriverToQueue(desc: MesosDriverDescription): Unit = { + queuedDriversState.persist(desc.submissionId, desc) + queuedDrivers += desc + revive() + } + + private def addDriverToPending(desc: MesosDriverDescription, taskId: String) = { + pendingRetryDriversState.persist(taskId, desc) + pendingRetryDrivers += desc + revive() + } + + private def revive(): Unit = { + logInfo("Reviving Offers.") + schedulerDriver.reviveOffers() + } } diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala index f69c223ab9b6d..85c2e9c76f4b0 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackend.scala @@ -26,6 +26,7 @@ import scala.collection.mutable import scala.concurrent.Future import org.apache.mesos.Protos.{TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.SchedulerDriver import org.apache.spark.{SecurityManager, SparkContext, SparkException, TaskState} import org.apache.spark.network.netty.SparkTransportConf @@ -119,11 +120,11 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Reject offers with mismatched constraints in seconds private val rejectOfferDurationForUnmetConstraints = - getRejectOfferDurationForUnmetConstraints(sc) + 
getRejectOfferDurationForUnmetConstraints(sc.conf) // Reject offers when we reached the maximum number of cores for this framework private val rejectOfferDurationForReachedMaxCores = - getRejectOfferDurationForReachedMaxCores(sc) + getRejectOfferDurationForReachedMaxCores(sc.conf) // A client for talking to the external shuffle service private val mesosExternalShuffleClient: Option[MesosExternalShuffleClient] = { @@ -146,6 +147,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( @volatile var appId: String = _ + private var schedulerDriver: SchedulerDriver = _ + def newMesosTaskId(): String = { val id = nextMesosTaskId nextMesosTaskId += 1 @@ -252,9 +255,12 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( override def offerRescinded(d: org.apache.mesos.SchedulerDriver, o: OfferID) {} override def registered( - d: org.apache.mesos.SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { - appId = frameworkId.getValue - mesosExternalShuffleClient.foreach(_.init(appId)) + driver: org.apache.mesos.SchedulerDriver, + frameworkId: FrameworkID, + masterInfo: MasterInfo) { + this.appId = frameworkId.getValue + this.mesosExternalShuffleClient.foreach(_.init(appId)) + this.schedulerDriver = driver markRegistered() } @@ -293,46 +299,25 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } private def declineUnmatchedOffers( - d: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = { + driver: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = { offers.foreach { offer => - declineOffer(d, offer, Some("unmet constraints"), + declineOffer( + driver, + offer, + Some("unmet constraints"), Some(rejectOfferDurationForUnmetConstraints)) } } - private def declineOffer( - d: org.apache.mesos.SchedulerDriver, - offer: Offer, - reason: Option[String] = None, - refuseSeconds: Option[Long] = None): Unit = { - - val id = offer.getId.getValue - val offerAttributes = toAttributeMap(offer.getAttributesList) - val mem = getResource(offer.getResourcesList, "mem") - val cpus = getResource(offer.getResourcesList, "cpus") - val ports = getRangeResource(offer.getResourcesList, "ports") - - logDebug(s"Declining offer: $id with attributes: $offerAttributes mem: $mem" + - s" cpu: $cpus port: $ports for $refuseSeconds seconds" + - reason.map(r => s" (reason: $r)").getOrElse("")) - - refuseSeconds match { - case Some(seconds) => - val filters = Filters.newBuilder().setRefuseSeconds(seconds).build() - d.declineOffer(offer.getId, filters) - case _ => d.declineOffer(offer.getId) - } - } - /** * Launches executors on accepted offers, and declines unused offers. Executors are launched * round-robin on offers. 
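
The declines above now funnel through a shared helper that is added to MesosSchedulerUtils later in this patch. Stripped to its core, declining an offer with a refusal period amounts to the following; a simplified sketch rather than the helper itself:

    import org.apache.mesos.Protos.{Filters, OfferID}
    import org.apache.mesos.SchedulerDriver

    def declineFor(driver: SchedulerDriver, offerId: OfferID, refuseSeconds: Long): Unit = {
      // Ask Mesos not to re-offer these resources to this framework for `refuseSeconds`.
      val filters = Filters.newBuilder().setRefuseSeconds(refuseSeconds).build()
      driver.declineOffer(offerId, filters)
    }
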
* - * @param d SchedulerDriver + * @param driver SchedulerDriver * @param offers Mesos offers that match attribute constraints */ private def handleMatchedOffers( - d: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = { + driver: org.apache.mesos.SchedulerDriver, offers: mutable.Buffer[Offer]): Unit = { val tasks = buildMesosTasks(offers) for (offer <- offers) { val offerAttributes = toAttributeMap(offer.getAttributesList) @@ -358,15 +343,19 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( s" ports: $ports") } - d.launchTasks( + driver.launchTasks( Collections.singleton(offer.getId), offerTasks.asJava) } else if (totalCoresAcquired >= maxCores) { // Reject an offer for a configurable amount of time to avoid starving other frameworks - declineOffer(d, offer, Some("reached spark.cores.max"), + declineOffer(driver, + offer, + Some("reached spark.cores.max"), Some(rejectOfferDurationForReachedMaxCores)) } else { - declineOffer(d, offer) + declineOffer( + driver, + offer) } } } @@ -582,8 +571,8 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( // Close the mesos external shuffle client if used mesosExternalShuffleClient.foreach(_.close()) - if (mesosDriver != null) { - mesosDriver.stop() + if (schedulerDriver != null) { + schedulerDriver.stop() } } @@ -634,13 +623,13 @@ private[spark] class MesosCoarseGrainedSchedulerBackend( } override def doKillExecutors(executorIds: Seq[String]): Future[Boolean] = Future.successful { - if (mesosDriver == null) { + if (schedulerDriver == null) { logWarning("Asked to kill executors before the Mesos driver was started.") false } else { for (executorId <- executorIds) { val taskId = TaskID.newBuilder().setValue(executorId).build() - mesosDriver.killTask(taskId) + schedulerDriver.killTask(taskId) } // no need to adjust `executorLimitOption` since the AllocationManager already communicated // the desired limit through a call to `doRequestTotalExecutors`. diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala index 7e561916a71e2..215271302ec51 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosFineGrainedSchedulerBackend.scala @@ -24,6 +24,7 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.{HashMap, HashSet} import org.apache.mesos.Protos.{ExecutorInfo => MesosExecutorInfo, TaskInfo => MesosTaskInfo, _} +import org.apache.mesos.SchedulerDriver import org.apache.mesos.protobuf.ByteString import org.apache.spark.{SparkContext, SparkException, TaskState} @@ -65,7 +66,9 @@ private[spark] class MesosFineGrainedSchedulerBackend( // reject offers with mismatched constraints in seconds private val rejectOfferDurationForUnmetConstraints = - getRejectOfferDurationForUnmetConstraints(sc) + getRejectOfferDurationForUnmetConstraints(sc.conf) + + private var schedulerDriver: SchedulerDriver = _ @volatile var appId: String = _ @@ -89,6 +92,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( /** * Creates a MesosExecutorInfo that is used to launch a Mesos executor. + * * @param availableResources Available resources that is offered by Mesos * @param execId The executor id to assign to this new executor. 
* @return A tuple of the new mesos executor info and the remaining available resources. @@ -178,10 +182,13 @@ private[spark] class MesosFineGrainedSchedulerBackend( override def offerRescinded(d: org.apache.mesos.SchedulerDriver, o: OfferID) {} override def registered( - d: org.apache.mesos.SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) { + driver: org.apache.mesos.SchedulerDriver, + frameworkId: FrameworkID, + masterInfo: MasterInfo) { inClassLoader() { appId = frameworkId.getValue logInfo("Registered as framework ID " + appId) + this.schedulerDriver = driver markRegistered() } } @@ -383,13 +390,13 @@ private[spark] class MesosFineGrainedSchedulerBackend( } override def stop() { - if (mesosDriver != null) { - mesosDriver.stop() + if (schedulerDriver != null) { + schedulerDriver.stop() } } override def reviveOffers() { - mesosDriver.reviveOffers() + schedulerDriver.reviveOffers() } override def frameworkMessage( @@ -426,7 +433,7 @@ private[spark] class MesosFineGrainedSchedulerBackend( } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean): Unit = { - mesosDriver.killTask( + schedulerDriver.killTask( TaskID.newBuilder() .setValue(taskId.toString).build() ) diff --git a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala index 1d742fefbbacf..3f25535cb5ec2 100644 --- a/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala +++ b/resource-managers/mesos/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerUtils.scala @@ -46,9 +46,6 @@ trait MesosSchedulerUtils extends Logging { // Lock used to wait for scheduler to be registered private final val registerLatch = new CountDownLatch(1) - // Driver for talking to Mesos - protected var mesosDriver: SchedulerDriver = null - /** * Creates a new MesosSchedulerDriver that communicates to the Mesos master. 
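
The shared mesosDriver field removed just above is replaced, in the hunks before and after this one, by a schedulerDriver handle that each backend captures in its registered() callback. A minimal sketch of that pattern, not the actual backend classes:

    import org.apache.mesos.SchedulerDriver

    class DriverHandleSketch {
      @volatile private var schedulerDriver: SchedulerDriver = _

      // Called from the Scheduler.registered() callback with the driver Mesos hands back.
      def onRegistered(driver: SchedulerDriver): Unit = {
        schedulerDriver = driver
      }

      // Later control actions all go through the captured handle.
      def revive(): Unit = schedulerDriver.reviveOffers()
      def shutdown(): Unit = if (schedulerDriver != null) schedulerDriver.stop()
    }
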
* @@ -115,10 +112,6 @@ trait MesosSchedulerUtils extends Logging { */ def startScheduler(newDriver: SchedulerDriver): Unit = { synchronized { - if (mesosDriver != null) { - registerLatch.await() - return - } @volatile var error: Option[Exception] = None @@ -128,8 +121,7 @@ trait MesosSchedulerUtils extends Logging { setDaemon(true) override def run() { try { - mesosDriver = newDriver - val ret = mesosDriver.run() + val ret = newDriver.run() logInfo("driver.run() returned with code " + ret) if (ret != null && ret.equals(Status.DRIVER_ABORTED)) { error = Some(new SparkException("Error starting driver, DRIVER_ABORTED")) @@ -379,12 +371,24 @@ trait MesosSchedulerUtils extends Logging { } } - protected def getRejectOfferDurationForUnmetConstraints(sc: SparkContext): Long = { - sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForUnmetConstraints", "120s") + private def getRejectOfferDurationStr(conf: SparkConf): String = { + conf.get("spark.mesos.rejectOfferDuration", "120s") + } + + protected def getRejectOfferDuration(conf: SparkConf): Long = { + Utils.timeStringAsSeconds(getRejectOfferDurationStr(conf)) + } + + protected def getRejectOfferDurationForUnmetConstraints(conf: SparkConf): Long = { + conf.getTimeAsSeconds( + "spark.mesos.rejectOfferDurationForUnmetConstraints", + getRejectOfferDurationStr(conf)) } - protected def getRejectOfferDurationForReachedMaxCores(sc: SparkContext): Long = { - sc.conf.getTimeAsSeconds("spark.mesos.rejectOfferDurationForReachedMaxCores", "120s") + protected def getRejectOfferDurationForReachedMaxCores(conf: SparkConf): Long = { + conf.getTimeAsSeconds( + "spark.mesos.rejectOfferDurationForReachedMaxCores", + getRejectOfferDurationStr(conf)) } /** @@ -438,6 +442,7 @@ trait MesosSchedulerUtils extends Logging { /** * The values of the non-zero ports to be used by the executor process. 
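
With the refactor above, both reject durations are read from SparkConf and fall back to the new common base setting. Illustrative values only:

    import org.apache.spark.SparkConf

    val conf = new SparkConf()
      .set("spark.mesos.rejectOfferDuration", "30s")

    // With only the base key set:
    //   getRejectOfferDurationForUnmetConstraints(conf)  => 30 (seconds)
    //   getRejectOfferDurationForReachedMaxCores(conf)   => 30 (seconds)
    // Setting either specific key explicitly still overrides the base, and the base
    // itself defaults to "120s" when unset.
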
+ * * @param conf the spark config to use * @return the ono-zero values of the ports */ @@ -521,4 +526,33 @@ trait MesosSchedulerUtils extends Logging { case TaskState.KILLED => MesosTaskState.TASK_KILLED case TaskState.LOST => MesosTaskState.TASK_LOST } + + protected def declineOffer( + driver: org.apache.mesos.SchedulerDriver, + offer: Offer, + reason: Option[String] = None, + refuseSeconds: Option[Long] = None): Unit = { + + val id = offer.getId.getValue + val offerAttributes = toAttributeMap(offer.getAttributesList) + val mem = getResource(offer.getResourcesList, "mem") + val cpus = getResource(offer.getResourcesList, "cpus") + val ports = getRangeResource(offer.getResourcesList, "ports") + + logDebug(s"Declining offer: $id with " + + s"attributes: $offerAttributes " + + s"mem: $mem " + + s"cpu: $cpus " + + s"port: $ports " + + refuseSeconds.map(s => s"for ${s} seconds ").getOrElse("") + + reason.map(r => s" (reason: $r)").getOrElse("")) + + refuseSeconds match { + case Some(seconds) => + val filters = Filters.newBuilder().setRefuseSeconds(seconds).build() + driver.declineOffer(offer.getId, filters) + case _ => + driver.declineOffer(offer.getId) + } + } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala index b9d098486b675..32967b04cd346 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosClusterSchedulerSuite.scala @@ -53,19 +53,32 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi override def start(): Unit = { ready = true } } scheduler.start() + scheduler.registered(driver, Utils.TEST_FRAMEWORK_ID, Utils.TEST_MASTER_INFO) + } + + private def testDriverDescription(submissionId: String): MesosDriverDescription = { + new MesosDriverDescription( + "d1", + "jar", + 1000, + 1, + true, + command, + Map[String, String](), + submissionId, + new Date()) } test("can queue drivers") { setScheduler() - val response = scheduler.submitDriver( - new MesosDriverDescription("d1", "jar", 1000, 1, true, - command, Map[String, String](), "s1", new Date())) + val response = scheduler.submitDriver(testDriverDescription("s1")) assert(response.success) - val response2 = - scheduler.submitDriver(new MesosDriverDescription( - "d1", "jar", 1000, 1, true, command, Map[String, String](), "s2", new Date())) + verify(driver, times(1)).reviveOffers() + + val response2 = scheduler.submitDriver(testDriverDescription("s2")) assert(response2.success) + val state = scheduler.getSchedulerState() val queuedDrivers = state.queuedDrivers.toList assert(queuedDrivers(0).submissionId == response.submissionId) @@ -75,9 +88,7 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi test("can kill queued drivers") { setScheduler() - val response = scheduler.submitDriver( - new MesosDriverDescription("d1", "jar", 1000, 1, true, - command, Map[String, String](), "s1", new Date())) + val response = scheduler.submitDriver(testDriverDescription("s1")) assert(response.success) val killResponse = scheduler.killDriver(response.submissionId) assert(killResponse.success) @@ -238,18 +249,10 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi } test("can kill supervised drivers") { - val driver = 
mock[SchedulerDriver] val conf = new SparkConf() conf.setMaster("mesos://localhost:5050") conf.setAppName("spark mesos") - scheduler = new MesosClusterScheduler( - new BlackHoleMesosClusterPersistenceEngineFactory, conf) { - override def start(): Unit = { - ready = true - mesosDriver = driver - } - } - scheduler.start() + setScheduler(conf.getAll.toMap) val response = scheduler.submitDriver( new MesosDriverDescription("d1", "jar", 100, 1, true, command, @@ -291,4 +294,16 @@ class MesosClusterSchedulerSuite extends SparkFunSuite with LocalSparkContext wi assert(state.launchedDrivers.isEmpty) assert(state.finishedDrivers.size == 1) } + + test("Declines offer with refuse seconds = 120.") { + setScheduler() + + val filter = Filters.newBuilder().setRefuseSeconds(120).build() + val offerId = OfferID.newBuilder().setValue("o1").build() + val offer = Utils.createOffer(offerId.getValue, "s1", 1000, 1) + + scheduler.resourceOffers(driver, Collections.singletonList(offer)) + + verify(driver, times(1)).declineOffer(offerId, filter) + } } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala index 78346e9744957..98033bec6dd68 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/MesosCoarseGrainedSchedulerBackendSuite.scala @@ -552,17 +552,14 @@ class MesosCoarseGrainedSchedulerBackendSuite extends SparkFunSuite override protected def getShuffleClient(): MesosExternalShuffleClient = shuffleClient // override to avoid race condition with the driver thread on `mesosDriver` - override def startScheduler(newDriver: SchedulerDriver): Unit = { - mesosDriver = newDriver - } + override def startScheduler(newDriver: SchedulerDriver): Unit = {} override def stopExecutors(): Unit = { stopCalled = true } - - markRegistered() } backend.start() + backend.registered(driver, Utils.TEST_FRAMEWORK_ID, Utils.TEST_MASTER_INFO) backend } diff --git a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala index 7ebb294aa9080..2a67cbc913ffe 100644 --- a/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala +++ b/resource-managers/mesos/src/test/scala/org/apache/spark/scheduler/cluster/mesos/Utils.scala @@ -28,6 +28,17 @@ import org.mockito.{ArgumentCaptor, Matchers} import org.mockito.Mockito._ object Utils { + + val TEST_FRAMEWORK_ID = FrameworkID.newBuilder() + .setValue("test-framework-id") + .build() + + val TEST_MASTER_INFO = MasterInfo.newBuilder() + .setId("test-master") + .setIp(0) + .setPort(0) + .build() + def createOffer( offerId: String, slaveId: String, diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala index 2fdb70a73c754..41b7b5d60b038 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/security/CredentialUpdater.scala @@ -60,7 +60,7 @@ private[spark] class CredentialUpdater( 
if (remainingTime <= 0) { credentialUpdater.schedule(credentialUpdaterRunnable, 1, TimeUnit.MINUTES) } else { - logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime millis.") + logInfo(s"Scheduling credentials refresh from HDFS in $remainingTime ms.") credentialUpdater.schedule(credentialUpdaterRunnable, remainingTime, TimeUnit.MILLISECONDS) } } @@ -81,8 +81,8 @@ private[spark] class CredentialUpdater( UserGroupInformation.getCurrentUser.addCredentials(newCredentials) logInfo("Credentials updated from credentials file.") - val remainingTime = getTimeOfNextUpdateFromFileName(credentialsStatus.getPath) - - System.currentTimeMillis() + val remainingTime = (getTimeOfNextUpdateFromFileName(credentialsStatus.getPath) + - System.currentTimeMillis()) if (remainingTime <= 0) TimeUnit.MINUTES.toMillis(1) else remainingTime } else { // If current credential file is older than expected, sleep 1 hour and check again. @@ -100,6 +100,7 @@ private[spark] class CredentialUpdater( TimeUnit.HOURS.toMillis(1) } + logInfo(s"Scheduling credentials refresh from HDFS in $timeToNextUpdate ms.") credentialUpdater.schedule( credentialUpdaterRunnable, timeToNextUpdate, TimeUnit.MILLISECONDS) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala index 56bfc00f97088..cff0efa979932 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystConf.scala @@ -61,6 +61,12 @@ trait CatalystConf { */ def cboEnabled: Boolean + /** Enables join reorder in CBO. */ + def joinReorderEnabled: Boolean + + /** The maximum number of joined nodes allowed in the dynamic programming algorithm. */ + def joinReorderDPThreshold: Int + override def clone(): CatalystConf = throw new CloneNotSupportedException() } @@ -77,6 +83,8 @@ case class SimpleCatalystConf( runSQLonFile: Boolean = true, crossJoinEnabled: Boolean = false, cboEnabled: Boolean = false, + joinReorderEnabled: Boolean = false, + joinReorderDPThreshold: Int = 12, warehousePath: String = "/user/hive/warehouse", sessionLocalTimeZone: String = TimeZone.getDefault().getID) extends CatalystConf { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 2f8489de6b000..93666f14958e9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -598,7 +598,7 @@ class Analyzer( execute(child) } view.copy(child = newChild) - case p @ SubqueryAlias(_, view: View, _) => + case p @ SubqueryAlias(_, view: View) => val newChild = resolveRelation(view) p.copy(child = newChild) case _ => plan @@ -606,7 +606,11 @@ class Analyzer( def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators { case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved => - i.copy(table = EliminateSubqueryAliases(lookupTableFromCatalog(u))) + lookupTableFromCatalog(u).canonicalized match { + case v: View => + u.failAnalysis(s"Inserting into a view is not allowed. 
View: ${v.desc.identifier}.") + case other => i.copy(table = other) + } case u: UnresolvedRelation => resolveRelation(u) } @@ -2359,7 +2363,7 @@ class Analyzer( */ object EliminateSubqueryAliases extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - case SubqueryAlias(_, child, _) => child + case SubqueryAlias(_, child) => child } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala index d548d97ab4ab8..0dcb44081f608 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala @@ -424,6 +424,9 @@ object FunctionRegistry { expression[BitwiseOr]("|"), expression[BitwiseXor]("^"), + // json + expression[StructToJson]("to_json"), + // Cast aliases (SPARK-16730) castAlias("boolean", BooleanType), castAlias("tinyint", ByteType), diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index df45c070f9129..6cfc4a4321316 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -582,7 +582,7 @@ class SessionCatalog( val table = formatTableName(name.table) if (db == globalTempViewManager.database) { globalTempViewManager.get(table).map { viewDef => - SubqueryAlias(table, viewDef, None) + SubqueryAlias(table, viewDef) }.getOrElse(throw new NoSuchTableException(db, table)) } else if (name.database.isDefined || !tempTables.contains(table)) { val metadata = externalCatalog.getTable(db, table) @@ -595,17 +595,17 @@ class SessionCatalog( desc = metadata, output = metadata.schema.toAttributes, child = parser.parsePlan(viewText)) - SubqueryAlias(table, child, Some(name.copy(table = table, database = Some(db)))) + SubqueryAlias(table, child) } else { val tableRelation = CatalogRelation( metadata, // we assume all the columns are nullable. 
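
Registering StructToJson under the name to_json above makes the function callable from SQL rather than only through the Dataset API. A quick check, reusing the examples from the expression description added further down; assumes a SparkSession named spark:

    spark.sql("SELECT to_json(named_struct('a', 1, 'b', 2))").collect()
    // single row containing: {"a":1,"b":2}

    // Options are passed as a string-to-string map, e.g. a custom timestamp format:
    spark.sql(
      "SELECT to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), " +
      "map('timestampFormat', 'dd/MM/yyyy'))").collect()
    // single row containing: {"time":"26/08/2015"}
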
metadata.dataSchema.asNullable.toAttributes, metadata.partitionSchema.asNullable.toAttributes) - SubqueryAlias(table, tableRelation, None) + SubqueryAlias(table, tableRelation) } } else { - SubqueryAlias(table, tempTables(table), None) + SubqueryAlias(table, tempTables(table)) } } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala index c062e4e84bcdd..35ca2a0aa53a2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala @@ -346,7 +346,7 @@ package object dsl { orderSpec: Seq[SortOrder]): LogicalPlan = Window(windowExpressions, partitionSpec, orderSpec, logicalPlan) - def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan, None) + def subquery(alias: Symbol): LogicalPlan = SubqueryAlias(alias.name, logicalPlan) def except(otherPlan: LogicalPlan): LogicalPlan = Except(logicalPlan, otherPlan) @@ -368,7 +368,10 @@ package object dsl { analysis.UnresolvedRelation(TableIdentifier(tableName)), Map.empty, logicalPlan, overwrite, false) - def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan, None) + def as(alias: String): LogicalPlan = SubqueryAlias(alias, logicalPlan) + + def coalesce(num: Integer): LogicalPlan = + Repartition(num, shuffle = false, logicalPlan) def repartition(num: Integer): LogicalPlan = Repartition(num, shuffle = true, logicalPlan) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala index 0782143d465b3..93fc565a53419 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala @@ -45,8 +45,8 @@ import org.apache.spark.util.Utils object ExpressionEncoder { def apply[T : TypeTag](): ExpressionEncoder[T] = { // We convert the not-serializable TypeTag into StructType and ClassTag. 
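
The encoder change below resolves the TypeTag against ScalaReflection.mirror instead of the tag's own mirror, which is what the ReplSuite test earlier in this patch exercises: classes defined in the REPL live in a REPL-specific classloader. In spark-shell terms:

    case class Click(id: Int)
    spark.implicits.newProductSeqEncoder[Click]  // expected to resolve without error, per the ReplSuite test
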
- val mirror = typeTag[T].mirror - val tpe = typeTag[T].tpe + val mirror = ScalaReflection.mirror + val tpe = typeTag[T].in(mirror).tpe if (ScalaReflection.optionOfProductType(tpe)) { throw new UnsupportedOperationException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala index a36d3507d92ec..7c60f7d57a99e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.util._ import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} - +import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper} object Cast { @@ -277,9 +277,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // LongConverter private[this] def castToLong(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toLong catch { - case _: NumberFormatException => null - }) + val result = new LongWrapper() + buildCast[UTF8String](_, s => if (s.toLong(result)) result.value else null) case BooleanType => buildCast[Boolean](_, b => if (b) 1L else 0L) case DateType => @@ -293,9 +292,8 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // IntConverter private[this] def castToInt(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toInt catch { - case _: NumberFormatException => null - }) + val result = new IntWrapper() + buildCast[UTF8String](_, s => if (s.toInt(result)) result.value else null) case BooleanType => buildCast[Boolean](_, b => if (b) 1 else 0) case DateType => @@ -309,8 +307,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // ShortConverter private[this] def castToShort(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toShort catch { - case _: NumberFormatException => null + val result = new IntWrapper() + buildCast[UTF8String](_, s => if (s.toShort(result)) { + result.value.toShort + } else { + null }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toShort else 0.toShort) @@ -325,8 +326,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String // ByteConverter private[this] def castToByte(from: DataType): Any => Any = from match { case StringType => - buildCast[UTF8String](_, s => try s.toByte catch { - case _: NumberFormatException => null + val result = new IntWrapper() + buildCast[UTF8String](_, s => if (s.toByte(result)) { + result.value.toByte + } else { + null }) case BooleanType => buildCast[Boolean](_, b => if (b) 1.toByte else 0.toByte) @@ -503,11 +507,11 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String case TimestampType => castToTimestampCode(from, ctx) case CalendarIntervalType => castToIntervalCode(from) case BooleanType => castToBooleanCode(from) - case ByteType => castToByteCode(from) - case ShortType => castToShortCode(from) - case IntegerType => castToIntCode(from) + case ByteType => castToByteCode(from, ctx) + case ShortType => castToShortCode(from, ctx) + case IntegerType => castToIntCode(from, ctx) case FloatType => castToFloatCode(from) - case LongType => 
castToLongCode(from) + case LongType => castToLongCode(from, ctx) case DoubleType => castToDoubleCode(from) case array: ArrayType => @@ -734,13 +738,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => s"$evPrim = $c != 0;" } - private[this] def castToByteCode(from: DataType): CastFunction = from match { + private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => + val wrapper = ctx.freshName("wrapper") + ctx.addMutableState("UTF8String.IntWrapper", wrapper, + s"$wrapper = new UTF8String.IntWrapper();") (c, evPrim, evNull) => s""" - try { - $evPrim = $c.toByte(); - } catch (java.lang.NumberFormatException e) { + if ($c.toByte($wrapper)) { + $evPrim = (byte) $wrapper.value; + } else { $evNull = true; } """ @@ -756,13 +763,18 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => s"$evPrim = (byte) $c;" } - private[this] def castToShortCode(from: DataType): CastFunction = from match { + private[this] def castToShortCode( + from: DataType, + ctx: CodegenContext): CastFunction = from match { case StringType => + val wrapper = ctx.freshName("wrapper") + ctx.addMutableState("UTF8String.IntWrapper", wrapper, + s"$wrapper = new UTF8String.IntWrapper();") (c, evPrim, evNull) => s""" - try { - $evPrim = $c.toShort(); - } catch (java.lang.NumberFormatException e) { + if ($c.toShort($wrapper)) { + $evPrim = (short) $wrapper.value; + } else { $evNull = true; } """ @@ -778,13 +790,16 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => s"$evPrim = (short) $c;" } - private[this] def castToIntCode(from: DataType): CastFunction = from match { + private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => + val wrapper = ctx.freshName("wrapper") + ctx.addMutableState("UTF8String.IntWrapper", wrapper, + s"$wrapper = new UTF8String.IntWrapper();") (c, evPrim, evNull) => s""" - try { - $evPrim = $c.toInt(); - } catch (java.lang.NumberFormatException e) { + if ($c.toInt($wrapper)) { + $evPrim = $wrapper.value; + } else { $evNull = true; } """ @@ -800,13 +815,17 @@ case class Cast(child: Expression, dataType: DataType, timeZoneId: Option[String (c, evPrim, evNull) => s"$evPrim = (int) $c;" } - private[this] def castToLongCode(from: DataType): CastFunction = from match { + private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match { case StringType => + val wrapper = ctx.freshName("wrapper") + ctx.addMutableState("UTF8String.LongWrapper", wrapper, + s"$wrapper = new UTF8String.LongWrapper();") + (c, evPrim, evNull) => s""" - try { - $evPrim = $c.toLong(); - } catch (java.lang.NumberFormatException e) { + if ($c.toLong($wrapper)) { + $evPrim = $wrapper.value; + } else { $evNull = true; } """ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala index dbff62efdddb6..18b5f2f7ed2e8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/jsonExpressions.scala @@ -23,11 +23,12 @@ import scala.util.parsing.combinator.RegexParsers import com.fasterxml.jackson.core._ +import org.apache.spark.sql.AnalysisException import 
org.apache.spark.sql.catalyst.analysis.TypeCheckResult import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json._ -import org.apache.spark.sql.catalyst.util.{GenericArrayData, ParseModes} +import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData, ParseModes} import org.apache.spark.sql.types._ import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.Utils @@ -330,7 +331,7 @@ case class GetJsonObject(json: Expression, path: Expression) // scalastyle:off line.size.limit @ExpressionDescription( - usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - Return a tuple like the function get_json_object, but it takes multiple names. All the input parameters and output column types are string.", + usage = "_FUNC_(jsonStr, p1, p2, ..., pn) - Returns a tuple like the function get_json_object, but it takes multiple names. All the input parameters and output column types are string.", extended = """ Examples: > SELECT _FUNC_('{"a":1, "b":2}', 'a', 'b'); @@ -564,6 +565,17 @@ case class JsonToStruct( /** * Converts a [[StructType]] to a json output string. */ +// scalastyle:off line.size.limit +@ExpressionDescription( + usage = "_FUNC_(expr[, options]) - Returns a json string with a given struct value", + extended = """ + Examples: + > SELECT _FUNC_(named_struct('a', 1, 'b', 2)); + {"a":1,"b":2} + > SELECT _FUNC_(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); + {"time":"26/08/2015"} + """) +// scalastyle:on line.size.limit case class StructToJson( options: Map[String, String], child: Expression, @@ -573,6 +585,14 @@ case class StructToJson( def this(options: Map[String, String], child: Expression) = this(options, child, None) + // Used in `FunctionRegistry` + def this(child: Expression) = this(Map.empty, child, None) + def this(child: Expression, options: Expression) = + this( + options = StructToJson.convertToMapData(options), + child = child, + timeZoneId = None) + @transient lazy val writer = new CharArrayWriter() @@ -613,3 +633,20 @@ case class StructToJson( override def inputTypes: Seq[AbstractDataType] = StructType :: Nil } + +object StructToJson { + + def convertToMapData(exp: Expression): Map[String, String] = exp match { + case m: CreateMap + if m.dataType.acceptsType(MapType(StringType, StringType, valueContainsNull = false)) => + val arrayMap = m.eval().asInstanceOf[ArrayBasedMapData] + ArrayBasedMapData.toScalaMap(arrayMap).map { case (key, value) => + key.toString -> value.toString + } + case m: CreateMap => + throw new AnalysisException( + s"A type of keys and values in map() must be string, but got ${m.dataType}") + case _ => + throw new AnalysisException("Must use a map() function for options") + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala new file mode 100644 index 0000000000000..b694561e5372d --- /dev/null +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/CostBasedJoinReorder.scala @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import scala.collection.mutable + +import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.expressions.{And, Attribute, AttributeSet, Expression, PredicateHelper} +import org.apache.spark.sql.catalyst.plans.{Inner, InnerLike} +import org.apache.spark.sql.catalyst.plans.logical.{BinaryNode, Join, LogicalPlan, Project} +import org.apache.spark.sql.catalyst.rules.Rule + + +/** + * Cost-based join reorder. + * We may have several join reorder algorithms in the future. This class is the entry of these + * algorithms, and chooses which one to use. + */ +case class CostBasedJoinReorder(conf: CatalystConf) extends Rule[LogicalPlan] with PredicateHelper { + def apply(plan: LogicalPlan): LogicalPlan = { + if (!conf.cboEnabled || !conf.joinReorderEnabled) { + plan + } else { + val result = plan transform { + case p @ Project(projectList, j @ Join(_, _, _: InnerLike, _)) => + reorder(p, p.outputSet) + case j @ Join(_, _, _: InnerLike, _) => + reorder(j, j.outputSet) + } + // After reordering is finished, convert OrderedJoin back to Join + result transform { + case oj: OrderedJoin => oj.join + } + } + } + + def reorder(plan: LogicalPlan, output: AttributeSet): LogicalPlan = { + val (items, conditions) = extractInnerJoins(plan) + val result = + // Do reordering if the number of items is appropriate and join conditions exist. + // We also need to check if costs of all items can be evaluated. + if (items.size > 2 && items.size <= conf.joinReorderDPThreshold && conditions.nonEmpty && + items.forall(_.stats(conf).rowCount.isDefined)) { + JoinReorderDP.search(conf, items, conditions, output).getOrElse(plan) + } else { + plan + } + // Set consecutive join nodes ordered. + replaceWithOrderedJoin(result) + } + + /** + * Extract consecutive inner joinable items and join conditions. + * This method works for bushy trees and left/right deep trees. 
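
    // Worked example for the extraction method documented here, using hypothetical
    // relations A, B, C with join keys a.k, b.k, c.k:
    //   input plan: Join(Join(A, B, Inner, Some(a.k === b.k)), C, Inner, Some(b.k === c.k))
    //   result:     (Seq(A, B, C), Set(a.k === b.k, b.k === c.k))
    // i.e. the flattened inner-join items plus the split conjunctive join predicates,
    // which is what JoinReorderDP.search consumes. A Project of plain attributes on top
    // of a join is looked through, so Project(cols, Join(...)) flattens the same way.
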
+ */ + private def extractInnerJoins(plan: LogicalPlan): (Seq[LogicalPlan], Set[Expression]) = { + plan match { + case Join(left, right, _: InnerLike, cond) => + val (leftPlans, leftConditions) = extractInnerJoins(left) + val (rightPlans, rightConditions) = extractInnerJoins(right) + (leftPlans ++ rightPlans, cond.toSet.flatMap(splitConjunctivePredicates) ++ + leftConditions ++ rightConditions) + case Project(projectList, join) if projectList.forall(_.isInstanceOf[Attribute]) => + extractInnerJoins(join) + case _ => + (Seq(plan), Set()) + } + } + + private def replaceWithOrderedJoin(plan: LogicalPlan): LogicalPlan = plan match { + case j @ Join(left, right, _: InnerLike, cond) => + val replacedLeft = replaceWithOrderedJoin(left) + val replacedRight = replaceWithOrderedJoin(right) + OrderedJoin(j.copy(left = replacedLeft, right = replacedRight)) + case p @ Project(_, join) => + p.copy(child = replaceWithOrderedJoin(join)) + case _ => + plan + } + + /** This is a wrapper class for a join node that has been ordered. */ + private case class OrderedJoin(join: Join) extends BinaryNode { + override def left: LogicalPlan = join.left + override def right: LogicalPlan = join.right + override def output: Seq[Attribute] = join.output + } +} + +/** + * Reorder the joins using a dynamic programming algorithm. This implementation is based on the + * paper: Access Path Selection in a Relational Database Management System. + * http://www.inf.ed.ac.uk/teaching/courses/adbs/AccessPath.pdf + * + * First we put all items (basic joined nodes) into level 0, then we build all two-way joins + * at level 1 from plans at level 0 (single items), then build all 3-way joins from plans + * at previous levels (two-way joins and single items), then 4-way joins ... etc, until we + * build all n-way joins and pick the best plan among them. + * + * When building m-way joins, we only keep the best plan (with the lowest cost) for the same set + * of m items. E.g., for 3-way joins, we keep only the best plan for items {A, B, C} among + * plans (A J B) J C, (A J C) J B and (B J C) J A. + * + * Thus the plans maintained for each level when reordering four items A, B, C, D are as follows: + * level 0: p({A}), p({B}), p({C}), p({D}) + * level 1: p({A, B}), p({A, C}), p({A, D}), p({B, C}), p({B, D}), p({C, D}) + * level 2: p({A, B, C}), p({A, B, D}), p({A, C, D}), p({B, C, D}) + * level 3: p({A, B, C, D}) + * where p({A, B, C, D}) is the final output plan. + * + * For cost evaluation, since physical costs for operators are not available currently, we use + * cardinalities and sizes to compute costs. + */ +object JoinReorderDP extends PredicateHelper { + + def search( + conf: CatalystConf, + items: Seq[LogicalPlan], + conditions: Set[Expression], + topOutput: AttributeSet): Option[LogicalPlan] = { + + // Level i maintains all found plans for i + 1 items. + // Create the initial plans: each plan is a single item with zero cost. + val itemIndex = items.zipWithIndex + val foundPlans = mutable.Buffer[JoinPlanMap](itemIndex.map { + case (item, id) => Set(id) -> JoinPlan(Set(id), item, Set(), Cost(0, 0)) + }.toMap) + + for (lev <- 1 until items.length) { + // Build plans for the next level. + foundPlans += searchLevel(foundPlans, conf, conditions, topOutput) + } + + val plansLastLevel = foundPlans(items.length - 1) + if (plansLastLevel.isEmpty) { + // Failed to find a plan, fall back to the original plan + None + } else { + // There must be only one plan at the last level, which contains all items. 
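
    // Descriptive note on when the reordering above actually kicks in: CostBasedJoinReorder
    // only rewrites a plan when conf.cboEnabled and conf.joinReorderEnabled are both true,
    // there are between 3 and conf.joinReorderDPThreshold (default 12) joined items, at
    // least one join condition exists, and every item has a row-count estimate. With the
    // SimpleCatalystConf defaults added earlier in this patch, joinReorderEnabled is false,
    // so plans pass through unchanged until the flag is switched on.
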
+ assert(plansLastLevel.size == 1 && plansLastLevel.head._1.size == items.length) + Some(plansLastLevel.head._2.plan) + } + } + + /** Find all possible plans at the next level, based on existing levels. */ + private def searchLevel( + existingLevels: Seq[JoinPlanMap], + conf: CatalystConf, + conditions: Set[Expression], + topOutput: AttributeSet): JoinPlanMap = { + + val nextLevel = mutable.Map.empty[Set[Int], JoinPlan] + var k = 0 + val lev = existingLevels.length - 1 + // Build plans for the next level from plans at level k (one side of the join) and level + // lev - k (the other side of the join). + // For the lower level k, we only need to search from 0 to lev - k, because when building + // a join from A and B, both A J B and B J A are handled. + while (k <= lev - k) { + val oneSideCandidates = existingLevels(k).values.toSeq + for (i <- oneSideCandidates.indices) { + val oneSidePlan = oneSideCandidates(i) + val otherSideCandidates = if (k == lev - k) { + // Both sides of a join are at the same level, no need to repeat for previous ones. + oneSideCandidates.drop(i) + } else { + existingLevels(lev - k).values.toSeq + } + + otherSideCandidates.foreach { otherSidePlan => + // Should not join two overlapping item sets. + if (oneSidePlan.itemIds.intersect(otherSidePlan.itemIds).isEmpty) { + val joinPlan = buildJoin(oneSidePlan, otherSidePlan, conf, conditions, topOutput) + // Check if it's the first plan for the item set, or it's a better plan than + // the existing one due to lower cost. + val existingPlan = nextLevel.get(joinPlan.itemIds) + if (existingPlan.isEmpty || joinPlan.cost.lessThan(existingPlan.get.cost)) { + nextLevel.update(joinPlan.itemIds, joinPlan) + } + } + } + } + k += 1 + } + nextLevel.toMap + } + + /** Build a new join node. */ + private def buildJoin( + oneJoinPlan: JoinPlan, + otherJoinPlan: JoinPlan, + conf: CatalystConf, + conditions: Set[Expression], + topOutput: AttributeSet): JoinPlan = { + + val onePlan = oneJoinPlan.plan + val otherPlan = otherJoinPlan.plan + // Now both onePlan and otherPlan become intermediate joins, so the cost of the + // new join should also include their own cardinalities and sizes. + val newCost = if (isCartesianProduct(onePlan) || isCartesianProduct(otherPlan)) { + // We consider cartesian product very expensive, thus set a very large cost for it. + // This enables to plan all the cartesian products at the end, because having a cartesian + // product as an intermediate join will significantly increase a plan's cost, making it + // impossible to be selected as the best plan for the items, unless there's no other choice. + Cost( + rows = BigInt(Long.MaxValue) * BigInt(Long.MaxValue), + size = BigInt(Long.MaxValue) * BigInt(Long.MaxValue)) + } else { + val onePlanStats = onePlan.stats(conf) + val otherPlanStats = otherPlan.stats(conf) + Cost( + rows = oneJoinPlan.cost.rows + onePlanStats.rowCount.get + + otherJoinPlan.cost.rows + otherPlanStats.rowCount.get, + size = oneJoinPlan.cost.size + onePlanStats.sizeInBytes + + otherJoinPlan.cost.size + otherPlanStats.sizeInBytes) + } + + // Put the deeper side on the left, tend to build a left-deep tree. 
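As a minimal sketch of the cost accumulation just above (the class and field names here are illustrative stand-ins, not the patch's API): joining two sub-plans adds the accumulated costs of both sides plus their own cardinalities and sizes, since both sides become intermediate results, and a cartesian product is priced so high that it can only win when no join condition connects the two sides.

```scala
// Illustrative stand-ins only; not the Catalyst types from this patch.
final case class PlanCost(rows: BigInt, size: BigInt)
final case class SubPlan(cost: PlanCost, rowCount: BigInt, sizeInBytes: BigInt)

object CostAccumulationDemo {
  // Accumulate cost when joining two sub-plans. Leaf relations start with cost 0,
  // so at the first level only the two base cardinalities and sizes contribute.
  def combine(one: SubPlan, other: SubPlan, cartesian: Boolean): PlanCost =
    if (cartesian) {
      // A cartesian product gets a prohibitively large cost so it is only chosen
      // when nothing else is possible.
      val huge = BigInt(Long.MaxValue) * BigInt(Long.MaxValue)
      PlanCost(huge, huge)
    } else {
      PlanCost(
        rows = one.cost.rows + one.rowCount + other.cost.rows + other.rowCount,
        size = one.cost.size + one.sizeInBytes + other.cost.size + other.sizeInBytes)
    }

  def main(args: Array[String]): Unit = {
    val a = SubPlan(PlanCost(0, 0), rowCount = 1000, sizeInBytes = 16000)
    val b = SubPlan(PlanCost(0, 0), rowCount = 20, sizeInBytes = 240)
    println(combine(a, b, cartesian = false)) // PlanCost(1020, 16240)
  }
}
```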
+ val (left, right) = if (oneJoinPlan.itemIds.size >= otherJoinPlan.itemIds.size) { + (onePlan, otherPlan) + } else { + (otherPlan, onePlan) + } + val joinConds = conditions + .filterNot(l => canEvaluate(l, onePlan)) + .filterNot(r => canEvaluate(r, otherPlan)) + .filter(e => e.references.subsetOf(onePlan.outputSet ++ otherPlan.outputSet)) + // We use inner join whether join condition is empty or not. Since cross join is + // equivalent to inner join without condition. + val newJoin = Join(left, right, Inner, joinConds.reduceOption(And)) + val collectedJoinConds = joinConds ++ oneJoinPlan.joinConds ++ otherJoinPlan.joinConds + val remainingConds = conditions -- collectedJoinConds + val neededAttr = AttributeSet(remainingConds.flatMap(_.references)) ++ topOutput + val neededFromNewJoin = newJoin.outputSet.filter(neededAttr.contains) + val newPlan = + if ((newJoin.outputSet -- neededFromNewJoin).nonEmpty) { + Project(neededFromNewJoin.toSeq, newJoin) + } else { + newJoin + } + + val itemIds = oneJoinPlan.itemIds.union(otherJoinPlan.itemIds) + JoinPlan(itemIds, newPlan, collectedJoinConds, newCost) + } + + private def isCartesianProduct(plan: LogicalPlan): Boolean = plan match { + case Join(_, _, _, None) => true + case Project(_, Join(_, _, _, None)) => true + case _ => false + } + + /** Map[set of item ids, join plan for these items] */ + type JoinPlanMap = Map[Set[Int], JoinPlan] + + /** + * Partial join order in a specific level. + * + * @param itemIds Set of item ids participating in this partial plan. + * @param plan The plan tree with the lowest cost for these items found so far. + * @param joinConds Join conditions included in the plan. + * @param cost The cost of this plan is the sum of costs of all intermediate joins. + */ + case class JoinPlan(itemIds: Set[Int], plan: LogicalPlan, joinConds: Set[Expression], cost: Cost) +} + +/** This class defines the cost model. */ +case class Cost(rows: BigInt, size: BigInt) { + /** + * An empirical value for the weights of cardinality (number of rows) in the cost formula: + * cost = rows * weight + size * (1 - weight), usually cardinality is more important than size. + */ + val weight = 0.7 + + def lessThan(other: Cost): Boolean = { + if (other.rows == 0 || other.size == 0) { + false + } else { + val relativeRows = BigDecimal(rows) / BigDecimal(other.rows) + val relativeSize = BigDecimal(size) / BigDecimal(other.size) + relativeRows * weight + relativeSize * (1 - weight) < 1 + } + } +} diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 036da3ad2062f..caafa1c134cd4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -118,6 +118,8 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf) SimplifyCreateMapOps) :: Batch("Check Cartesian Products", Once, CheckCartesianProducts(conf)) :: + Batch("Join Reorder", Once, + CostBasedJoinReorder(conf)) :: Batch("Decimal Optimizations", fixedPoint, DecimalAggregates(conf)) :: Batch("Typed Filter Optimization", fixedPoint, @@ -562,27 +564,23 @@ object CollapseProject extends Rule[LogicalPlan] { } /** - * Combines adjacent [[Repartition]] and [[RepartitionByExpression]] operator combinations - * by keeping only the one. - * 1. For adjacent [[Repartition]]s, collapse into the last [[Repartition]]. - * 2. 
For adjacent [[RepartitionByExpression]]s, collapse into the last [[RepartitionByExpression]]. - * 3. For a combination of [[Repartition]] and [[RepartitionByExpression]], collapse as a single - * [[RepartitionByExpression]] with the expression and last number of partition. + * Combines adjacent [[RepartitionOperation]] operators */ object CollapseRepartition extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { - // Case 1 - case Repartition(numPartitions, shuffle, Repartition(_, _, child)) => - Repartition(numPartitions, shuffle, child) - // Case 2 - case RepartitionByExpression(exprs, RepartitionByExpression(_, child, _), numPartitions) => - RepartitionByExpression(exprs, child, numPartitions) - // Case 3 - case Repartition(numPartitions, _, r: RepartitionByExpression) => - r.copy(numPartitions = numPartitions) - // Case 3 - case RepartitionByExpression(exprs, Repartition(_, _, child), numPartitions) => - RepartitionByExpression(exprs, child, numPartitions) + // Case 1: When a Repartition has a child of Repartition or RepartitionByExpression, + // 1) When the top node does not enable the shuffle (i.e., coalesce API), but the child + // enables the shuffle. Returns the child node if the last numPartitions is bigger; + // otherwise, keep unchanged. + // 2) In the other cases, returns the top node with the child's child + case r @ Repartition(_, _, child: RepartitionOperation) => (r.shuffle, child.shuffle) match { + case (false, true) => if (r.numPartitions >= child.numPartitions) child else r + case _ => r.copy(child = child.child) + } + // Case 2: When a RepartitionByExpression has a child of Repartition or RepartitionByExpression + // we can remove the child. + case r @ RepartitionByExpression(_, child: RepartitionOperation, _) => + r.copy(child = child.child) } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala index 4d62cce9da0ac..fb7ce6aecea53 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala @@ -169,7 +169,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { // and Project operators, followed by an optional Filter, followed by an // Aggregate. Traverse the operators recursively. 
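Stepping back to the CollapseRepartition rewrite shown earlier in this patch: a toy, self-contained model of its case analysis (hypothetical classes, applied once at the top rather than through Catalyst's transformUp) may make the coalesce-versus-repartition behavior easier to follow. Here `shuffle = false` plays the role of `coalesce()` and `shuffle = true` that of `repartition()`.

```scala
// Toy stand-ins for the Catalyst operators; illustrative only.
sealed trait ToyPlan
case class Table(name: String) extends ToyPlan
case class Repart(numPartitions: Int, shuffle: Boolean, child: ToyPlan) extends ToyPlan
case class RepartByExpr(numPartitions: Int, child: ToyPlan) extends ToyPlan

object CollapseDemo {
  def collapse(plan: ToyPlan): ToyPlan = plan match {
    // A coalesce (no shuffle) over a shuffling child is useless when it asks for at least
    // as many partitions as the child produces; otherwise both nodes are kept.
    case top @ Repart(n, false, child @ Repart(m, true, _))  => if (n >= m) child else top
    case top @ Repart(n, false, child @ RepartByExpr(m, _))  => if (n >= m) child else top
    // In every other repartition-over-repartition case the top node wins.
    case Repart(n, shuffle, Repart(_, _, grandChild))        => Repart(n, shuffle, grandChild)
    case Repart(n, shuffle, RepartByExpr(_, grandChild))     => Repart(n, shuffle, grandChild)
    // A RepartitionByExpression always wins over a repartition operation below it.
    case RepartByExpr(n, Repart(_, _, grandChild))           => RepartByExpr(n, grandChild)
    case RepartByExpr(n, RepartByExpr(_, grandChild))        => RepartByExpr(n, grandChild)
    case other                                               => other
  }

  def main(args: Array[String]): Unit = {
    // repartition(10) then coalesce(20): the coalesce is dropped, 10 partitions remain.
    println(collapse(Repart(20, shuffle = false, Repart(10, shuffle = true, Table("t")))))
    // repartition(30) then coalesce(20): the coalesce can still reduce partitions, keep both.
    println(collapse(Repart(20, shuffle = false, Repart(30, shuffle = true, Table("t")))))
  }
}
```

This mirrors the "coalesce above repartition" expectations in the CollapseRepartitionSuite changes later in this diff.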
def evalPlan(lp : LogicalPlan) : Map[ExprId, Option[Any]] = lp match { - case SubqueryAlias(_, child, _) => evalPlan(child) + case SubqueryAlias(_, child) => evalPlan(child) case Filter(condition, child) => val bindings = evalPlan(child) if (bindings.isEmpty) bindings @@ -227,7 +227,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { topPart += p bottomPart = child - case s @ SubqueryAlias(_, child, _) => + case s @ SubqueryAlias(_, child) => topPart += s bottomPart = child @@ -298,8 +298,8 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] { topPart.reverse.foreach { case Project(projList, _) => subqueryRoot = Project(projList ++ havingInputs, subqueryRoot) - case s @ SubqueryAlias(alias, _, None) => - subqueryRoot = SubqueryAlias(alias, subqueryRoot, None) + case s @ SubqueryAlias(alias, _) => + subqueryRoot = SubqueryAlias(alias, subqueryRoot) case op => sys.error(s"Unexpected operator $op in corelated subquery") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala index d2e091f4dda69..3cf11adc1953b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala @@ -108,7 +108,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * This is only used for Common Table Expressions. */ override def visitNamedQuery(ctx: NamedQueryContext): SubqueryAlias = withOrigin(ctx) { - SubqueryAlias(ctx.name.getText, plan(ctx.query), None) + SubqueryAlias(ctx.name.getText, plan(ctx.query)) } /** @@ -666,7 +666,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { val tableWithAlias = Option(ctx.strictIdentifier).map(_.getText) match { case Some(strictIdentifier) => - SubqueryAlias(strictIdentifier, table, None) + SubqueryAlias(strictIdentifier, table) case _ => table } tableWithAlias.optionalMap(ctx.sample)(withSample) @@ -731,7 +731,7 @@ class AstBuilder extends SqlBaseBaseVisitor[AnyRef] with Logging { * Create an alias (SubqueryAlias) for a LogicalPlan. 
*/ private def aliasPlan(alias: ParserRuleContext, plan: LogicalPlan): LogicalPlan = { - SubqueryAlias(alias.getText, plan, None) + SubqueryAlias(alias.getText, plan) } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala index 77309ce391a1a..62f68a6d7b528 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/EventTimeWatermark.scala @@ -42,6 +42,13 @@ case class EventTimeWatermark( .putLong(EventTimeWatermark.delayKey, delay.milliseconds) .build() a.withMetadata(updatedMetadata) + } else if (a.metadata.contains(EventTimeWatermark.delayKey)) { + // Remove existing watermark + val updatedMetadata = new MetadataBuilder() + .withMetadata(a.metadata) + .remove(EventTimeWatermark.delayKey) + .build() + a.withMetadata(updatedMetadata) } else { a } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala index ccebae3cc2701..31b6ed48a2230 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala @@ -752,14 +752,13 @@ case class GlobalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryN } override def computeStats(conf: CatalystConf): Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val sizeInBytes = if (limit == 0) { - // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero - // (product of children). - 1 - } else { - (limit: Long) * output.map(a => a.dataType.defaultSize).sum - } - child.stats(conf).copy(sizeInBytes = sizeInBytes) + val childStats = child.stats(conf) + val rowCount: BigInt = childStats.rowCount.map(_.min(limit)).getOrElse(limit) + // Don't propagate column stats, because we don't know the distribution after a limit operation + Statistics( + sizeInBytes = EstimationUtils.getOutputSize(output, rowCount, childStats.attributeStats), + rowCount = Some(rowCount), + isBroadcastable = childStats.isBroadcastable) } } @@ -773,21 +772,27 @@ case class LocalLimit(limitExpr: Expression, child: LogicalPlan) extends UnaryNo } override def computeStats(conf: CatalystConf): Statistics = { val limit = limitExpr.eval().asInstanceOf[Int] - val sizeInBytes = if (limit == 0) { + val childStats = child.stats(conf) + if (limit == 0) { // sizeInBytes can't be zero, or sizeInBytes of BinaryNode will also be zero // (product of children). - 1 + Statistics( + sizeInBytes = 1, + rowCount = Some(0), + isBroadcastable = childStats.isBroadcastable) } else { - (limit: Long) * output.map(a => a.dataType.defaultSize).sum + // The output row count of LocalLimit should be the sum of row counts from each partition. + // However, since the number of partitions is not available here, we just use statistics of + // the child. Because the distribution after a limit operation is unknown, we do not propagate + // the column stats. 
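A quick sanity check of the new limit estimates, under the row-size convention used by the stats test suites below (8 bytes of per-row overhead plus 4 bytes for a single int column); the helper name here is made up for the illustration.

```scala
object LimitStatsCheck {
  def main(args: Array[String]): Unit = {
    // Hypothetical helper mirroring the per-row size convention in the test suites below.
    def outputSize(rowCount: BigInt, bytesPerRow: Int): BigInt = rowCount * bytesPerRow

    val childRowCount = BigInt(10)                        // child: 10 rows of one int column
    val limit = 2
    val estimatedRows = childRowCount.min(BigInt(limit))  // GlobalLimit rows = min(child, limit) = 2
    val estimatedSize = outputSize(estimatedRows, 8 + 4)  // = 24 bytes

    println(s"GlobalLimit($limit): rows = $estimatedRows, size = $estimatedSize")
    // LocalLimit keeps the child's estimate (minus column stats) because per-partition
    // row counts are unknown; a limit of 0 degenerates to rowCount = 0 and sizeInBytes = 1.
  }
}
```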
+ childStats.copy(attributeStats = AttributeMap(Nil)) } - child.stats(conf).copy(sizeInBytes = sizeInBytes) } } case class SubqueryAlias( alias: String, - child: LogicalPlan, - view: Option[TableIdentifier]) + child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias))) @@ -816,12 +821,14 @@ case class Sample( override def computeStats(conf: CatalystConf): Statistics = { val ratio = upperBound - lowerBound - // BigInt can't multiply with Double - var sizeInBytes = child.stats(conf).sizeInBytes * (ratio * 100).toInt / 100 + val childStats = child.stats(conf) + var sizeInBytes = EstimationUtils.ceil(BigDecimal(childStats.sizeInBytes) * ratio) if (sizeInBytes == 0) { sizeInBytes = 1 } - child.stats(conf).copy(sizeInBytes = sizeInBytes) + val sampledRowCount = childStats.rowCount.map(c => EstimationUtils.ceil(BigDecimal(c) * ratio)) + // Don't propagate column stats, because we don't know the distribution after a sample operation + Statistics(sizeInBytes, sampledRowCount, isBroadcastable = childStats.isBroadcastable) } override protected def otherCopyArgs: Seq[AnyRef] = isTableSample :: Nil @@ -835,6 +842,15 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { override def output: Seq[Attribute] = child.output } +/** + * A base interface for [[RepartitionByExpression]] and [[Repartition]] + */ +abstract class RepartitionOperation extends UnaryNode { + def shuffle: Boolean + def numPartitions: Int + override def output: Seq[Attribute] = child.output +} + /** * Returns a new RDD that has exactly `numPartitions` partitions. Differs from * [[RepartitionByExpression]] as this method is called directly by DataFrame's, because the user @@ -842,9 +858,8 @@ case class Distinct(child: LogicalPlan) extends UnaryNode { * of the output requires some specific ordering or distribution of the data. 
*/ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) - extends UnaryNode { + extends RepartitionOperation { require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.") - override def output: Seq[Attribute] = child.output } /** @@ -856,12 +871,12 @@ case class Repartition(numPartitions: Int, shuffle: Boolean, child: LogicalPlan) case class RepartitionByExpression( partitionExpressions: Seq[Expression], child: LogicalPlan, - numPartitions: Int) extends UnaryNode { + numPartitions: Int) extends RepartitionOperation { require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.") override def maxRows: Option[Long] = child.maxRows - override def output: Seq[Attribute] = child.output + override def shuffle: Boolean = true } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala index 0c928832d7d22..b10785b05d6c7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/statsEstimation/FilterEstimation.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql.catalyst.plans.logical.statsEstimation import scala.collection.immutable.HashSet import scala.collection.mutable +import scala.math.BigDecimal.RoundingMode import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.CatalystConf import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Statistics} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.types._ @@ -52,17 +53,19 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo def estimate: Option[Statistics] = { if (childStats.rowCount.isEmpty) return None - // save a mutable copy of colStats so that we can later change it recursively + // Save a mutable copy of colStats so that we can later change it recursively. colStatsMap.setInitValues(childStats.attributeStats) - // estimate selectivity of this filter predicate - val filterSelectivity: Double = calculateFilterSelectivity(plan.condition) match { - case Some(percent) => percent - // for not-supported condition, set filter selectivity to a conservative estimate 100% - case None => 1.0 - } + // Estimate selectivity of this filter predicate, and update column stats if needed. + // For not-supported condition, set filter selectivity to a conservative estimate 100% + val filterSelectivity: Double = calculateFilterSelectivity(plan.condition).getOrElse(1.0) - val newColStats = colStatsMap.toColumnStats + val newColStats = if (filterSelectivity == 0) { + // The output is empty, we don't need to keep column stats. + AttributeMap[ColumnStat](Nil) + } else { + colStatsMap.toColumnStats + } val filteredRowCount: BigInt = EstimationUtils.ceil(BigDecimal(childStats.rowCount.get) * filterSelectivity) @@ -74,12 +77,14 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } /** - * Returns a percentage of rows meeting a compound condition in Filter node. - * A compound condition is decomposed into multiple single conditions linked with AND, OR, NOT. + * Returns a percentage of rows meeting a condition in Filter node. 
+ * If it's a single condition, we calculate the percentage directly. + * If it's a compound condition, it is decomposed into multiple single conditions linked with + * AND, OR, NOT. * For logical AND conditions, we need to update stats after a condition estimation * so that the stats will be more accurate for subsequent estimation. This is needed for * range condition such as (c > 40 AND c <= 50) - * For logical OR conditions, we do not update stats after a condition estimation. + * For logical OR and NOT conditions, we do not update stats after a condition estimation. * * @param condition the compound logical expression * @param update a boolean flag to specify if we need to update ColumnStat of a column @@ -90,34 +95,29 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo def calculateFilterSelectivity(condition: Expression, update: Boolean = true): Option[Double] = { condition match { case And(cond1, cond2) => - // For ease of debugging, we compute percent1 and percent2 in 2 statements. - val percent1 = calculateFilterSelectivity(cond1, update) - val percent2 = calculateFilterSelectivity(cond2, update) - (percent1, percent2) match { - case (Some(p1), Some(p2)) => Some(p1 * p2) - case (Some(p1), None) => Some(p1) - case (None, Some(p2)) => Some(p2) - case (None, None) => None - } + val percent1 = calculateFilterSelectivity(cond1, update).getOrElse(1.0) + val percent2 = calculateFilterSelectivity(cond2, update).getOrElse(1.0) + Some(percent1 * percent2) case Or(cond1, cond2) => - // For ease of debugging, we compute percent1 and percent2 in 2 statements. - val percent1 = calculateFilterSelectivity(cond1, update = false) - val percent2 = calculateFilterSelectivity(cond2, update = false) - (percent1, percent2) match { - case (Some(p1), Some(p2)) => Some(math.min(1.0, p1 + p2 - (p1 * p2))) - case (Some(p1), None) => Some(1.0) - case (None, Some(p2)) => Some(1.0) - case (None, None) => None - } + val percent1 = calculateFilterSelectivity(cond1, update = false).getOrElse(1.0) + val percent2 = calculateFilterSelectivity(cond2, update = false).getOrElse(1.0) + Some(percent1 + percent2 - (percent1 * percent2)) - case Not(cond) => calculateFilterSelectivity(cond, update = false) match { - case Some(percent) => Some(1.0 - percent) - // for not-supported condition, set filter selectivity to a conservative estimate 100% - case None => None - } + case Not(And(cond1, cond2)) => + calculateFilterSelectivity(Or(Not(cond1), Not(cond2)), update = false) + + case Not(Or(cond1, cond2)) => + calculateFilterSelectivity(And(Not(cond1), Not(cond2)), update = false) - case _ => calculateSingleCondition(condition, update) + case Not(cond) => + calculateFilterSelectivity(cond, update = false) match { + case Some(percent) => Some(1.0 - percent) + case None => None + } + + case _ => + calculateSingleCondition(condition, update) } } @@ -225,12 +225,12 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo } val percent = if (isNull) { - nullPercent.toDouble + nullPercent } else { - 1.0 - nullPercent.toDouble + 1.0 - nullPercent } - Some(percent) + Some(percent.toDouble) } /** @@ -249,17 +249,19 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo attr: Attribute, literal: Literal, update: Boolean): Option[Double] = { + if (!colStatsMap.contains(attr)) { + logDebug("[CBO] No statistics for " + attr) + return None + } + attr.dataType match { - case _: NumericType | DateType | TimestampType => + case _: NumericType | DateType | 
TimestampType | BooleanType => evaluateBinaryForNumeric(op, attr, literal, update) case StringType | BinaryType => // TODO: It is difficult to support other binary comparisons for String/Binary // type without min/max and advanced statistics like histogram. logDebug("[CBO] No range comparison statistics for String/Binary type " + attr) None - case _ => - // TODO: support boolean type. - None } } @@ -291,6 +293,10 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo * Returns a percentage of rows meeting an equality (=) expression. * This method evaluates the equality predicate for all data types. * + * For EqualNullSafe (<=>), if the literal is not null, result will be the same as EqualTo; + * if the literal is null, the condition will be changed to IsNull after optimization. + * So we don't need specific logic for EqualNullSafe here. + * * @param attr an Attribute (or a column) * @param literal a literal value (or constant) * @param update a boolean flag to specify if we need to update ColumnStat of a given column @@ -323,7 +329,7 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo colStatsMap(attr) = newStats } - Some(1.0 / ndv.toDouble) + Some((1.0 / BigDecimal(ndv)).toDouble) } else { Some(0.0) } @@ -394,12 +400,12 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo // return the filter selectivity. Without advanced statistics such as histograms, // we have to assume uniform distribution. - Some(math.min(1.0, newNdv.toDouble / ndv.toDouble)) + Some(math.min(1.0, (BigDecimal(newNdv) / BigDecimal(ndv)).toDouble)) } /** * Returns a percentage of rows meeting a binary comparison expression. - * This method evaluate expression for Numeric columns only. + * This method evaluate expression for Numeric/Date/Timestamp/Boolean columns. 
* * @param op a binary comparison operator uch as =, <, <=, >, >= * @param attr an Attribute (or a column) @@ -414,53 +420,66 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo literal: Literal, update: Boolean): Option[Double] = { - var percent = 1.0 val colStat = colStatsMap(attr) - val statsRange = - Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] + val statsRange = Range(colStat.min, colStat.max, attr.dataType).asInstanceOf[NumericRange] + val max = BigDecimal(statsRange.max) + val min = BigDecimal(statsRange.min) + val ndv = BigDecimal(colStat.distinctCount) // determine the overlapping degree between predicate range and column's range - val literalValueBD = BigDecimal(literal.value.toString) + val numericLiteral = if (literal.dataType == BooleanType) { + if (literal.value.asInstanceOf[Boolean]) BigDecimal(1) else BigDecimal(0) + } else { + BigDecimal(literal.value.toString) + } val (noOverlap: Boolean, completeOverlap: Boolean) = op match { case _: LessThan => - (literalValueBD <= statsRange.min, literalValueBD > statsRange.max) + (numericLiteral <= min, numericLiteral > max) case _: LessThanOrEqual => - (literalValueBD < statsRange.min, literalValueBD >= statsRange.max) + (numericLiteral < min, numericLiteral >= max) case _: GreaterThan => - (literalValueBD >= statsRange.max, literalValueBD < statsRange.min) + (numericLiteral >= max, numericLiteral < min) case _: GreaterThanOrEqual => - (literalValueBD > statsRange.max, literalValueBD <= statsRange.min) + (numericLiteral > max, numericLiteral <= min) } + var percent = BigDecimal(1.0) if (noOverlap) { percent = 0.0 } else if (completeOverlap) { percent = 1.0 } else { - // this is partial overlap case - val literalDouble = literalValueBD.toDouble - val maxDouble = BigDecimal(statsRange.max).toDouble - val minDouble = BigDecimal(statsRange.min).toDouble - + // This is the partial overlap case: // Without advanced statistics like histogram, we assume uniform data distribution. // We just prorate the adjusted range over the initial range to compute filter selectivity. - // For ease of computation, we convert all relevant numeric values to Double. + assert(max > min) percent = op match { case _: LessThan => - (literalDouble - minDouble) / (maxDouble - minDouble) + if (numericLiteral == max) { + // If the literal value is right on the boundary, we can minus the part of the + // boundary value (1/ndv). + 1.0 - 1.0 / ndv + } else { + (numericLiteral - min) / (max - min) + } case _: LessThanOrEqual => - if (literalValueBD == BigDecimal(statsRange.min)) { - 1.0 / colStat.distinctCount.toDouble + if (numericLiteral == min) { + // The boundary value is the only satisfying value. 
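The partial-overlap proration and the new boundary handling above can be checked with a small standalone calculation, assuming a uniform distribution and a column shaped like `cint` in the suite at the end of this diff (10 rows, values 1 through 10, ndv 10); the helper is illustrative, not the patch's `EstimationUtils`.

```scala
object RangeSelectivityCheck {
  def main(args: Array[String]): Unit = {
    // Column shaped like `cint`: 10 rows with values 1..10, so min = 1, max = 10, ndv = 10.
    val (min, max, ndv, rowCount) =
      (BigDecimal(1), BigDecimal(10), BigDecimal(10), BigDecimal(10))
    // Row counts are rounded up, mirroring the ceiling used for filtered row counts.
    def ceil(x: BigDecimal): BigInt = x.setScale(0, BigDecimal.RoundingMode.CEILING).toBigInt

    // cint < 3: partial overlap, prorated over the column's range.
    val ltThree = (BigDecimal(3) - min) / (max - min)               // 2/9
    println(ceil(rowCount * ltThree))                               // ceil(2.22...) = 3 rows

    // cint <= 1: the literal sits exactly on the lower bound, only that value matches.
    println(ceil(rowCount * (BigDecimal(1) / ndv)))                 // 1 row

    // cint < 10: the literal equals the upper bound, everything except that value matches.
    println(ceil(rowCount * (BigDecimal(1) - BigDecimal(1) / ndv))) // 9 rows

    // cint < 1 is a no-overlap case (selectivity 0); cint >= 1 is a complete overlap (1.0).
  }
}
```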
+ 1.0 / ndv } else { - (literalDouble - minDouble) / (maxDouble - minDouble) + (numericLiteral - min) / (max - min) } case _: GreaterThan => - (maxDouble - literalDouble) / (maxDouble - minDouble) + if (numericLiteral == min) { + 1.0 - 1.0 / ndv + } else { + (max - numericLiteral) / (max - min) + } case _: GreaterThanOrEqual => - if (literalValueBD == BigDecimal(statsRange.max)) { - 1.0 / colStat.distinctCount.toDouble + if (numericLiteral == max) { + 1.0 / ndv } else { - (maxDouble - literalDouble) / (maxDouble - minDouble) + (max - numericLiteral) / (max - min) } } @@ -469,22 +488,25 @@ case class FilterEstimation(plan: Filter, catalystConf: CatalystConf) extends Lo val newValue = convertBoundValue(attr.dataType, literal.value) var newMax = colStat.max var newMin = colStat.min + var newNdv = (ndv * percent).setScale(0, RoundingMode.HALF_UP).toBigInt() + if (newNdv < 1) newNdv = 1 + op match { - case _: GreaterThan => newMin = newValue - case _: GreaterThanOrEqual => newMin = newValue - case _: LessThan => newMax = newValue - case _: LessThanOrEqual => newMax = newValue + case _: GreaterThan | _: GreaterThanOrEqual => + // If new ndv is 1, then new max must be equal to new min. + newMin = if (newNdv == 1) newMax else newValue + case _: LessThan | _: LessThanOrEqual => + newMax = if (newNdv == 1) newMin else newValue } - val newNdv = math.max(math.round(colStat.distinctCount.toDouble * percent), 1) - val newStats = colStat.copy(distinctCount = newNdv, min = newMin, - max = newMax, nullCount = 0) + val newStats = + colStat.copy(distinctCount = newNdv, min = newMin, max = newMax, nullCount = 0) colStatsMap(attr) = newStats } } - Some(percent) + Some(percent.toDouble) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala index 01737e0a17341..893bb1b74cea7 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala @@ -62,23 +62,23 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { checkAnalysis( Project(Seq(UnresolvedAttribute("TbL.a")), - SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Project(testRelation.output, testRelation)) assertAnalysisError( Project(Seq(UnresolvedAttribute("tBl.a")), - SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Seq("cannot resolve")) checkAnalysis( Project(Seq(UnresolvedAttribute("TbL.a")), - SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Project(testRelation.output, testRelation), caseSensitive = false) checkAnalysis( Project(Seq(UnresolvedAttribute("tBl.a")), - SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")), None)), + SubqueryAlias("TbL", UnresolvedRelation(TableIdentifier("TaBlE")))), Project(testRelation.output, testRelation), caseSensitive = false) } @@ -374,8 +374,8 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { val query = Project(Seq($"x.key", $"y.key"), Join( - Project(Seq($"x.key"), SubqueryAlias("x", input, None)), - Project(Seq($"y.key"), SubqueryAlias("y", input, None)), + Project(Seq($"x.key"), SubqueryAlias("x", input)), + 
Project(Seq($"y.key"), SubqueryAlias("y", input)), Cross, None)) assertAnalysisSuccess(query) @@ -435,10 +435,10 @@ class AnalysisSuite extends AnalysisTest with ShouldMatchers { test("resolve as with an already existed alias") { checkAnalysis( Project(Seq(UnresolvedAttribute("tbl2.a")), - SubqueryAlias("tbl", testRelation, None).as("tbl2")), + SubqueryAlias("tbl", testRelation).as("tbl2")), Project(testRelation.output, testRelation), caseSensitive = false) - checkAnalysis(SubqueryAlias("tbl", testRelation, None).as("tbl2"), testRelation) + checkAnalysis(SubqueryAlias("tbl", testRelation).as("tbl2"), testRelation) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala index 52b8225f70604..7e74dcdef0e27 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalogSuite.scala @@ -439,7 +439,7 @@ class SessionCatalogSuite extends PlanTest { .asInstanceOf[CatalogRelation].tableMeta == metastoreTable1) // Otherwise, we'll first look up a temporary table with the same name assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")) - == SubqueryAlias("tbl1", tempTable1, None)) + == SubqueryAlias("tbl1", tempTable1)) // Then, if that does not exist, look up the relation in the current database sessionCatalog.dropTable(TableIdentifier("tbl1"), ignoreIfNotExists = false, purge = false) assert(sessionCatalog.lookupRelation(TableIdentifier("tbl1")).children.head @@ -456,11 +456,11 @@ class SessionCatalogSuite extends PlanTest { val view = View(desc = metadata, output = metadata.schema.toAttributes, child = CatalystSqlParser.parsePlan(metadata.viewText.get)) comparePlans(sessionCatalog.lookupRelation(TableIdentifier("view1", Some("db3"))), - SubqueryAlias("view1", view, Some(TableIdentifier("view1", Some("db3"))))) + SubqueryAlias("view1", view)) // Look up a view using current database of the session catalog. 
sessionCatalog.setCurrentDatabase("db3") comparePlans(sessionCatalog.lookupRelation(TableIdentifier("view1")), - SubqueryAlias("view1", view, Some(TableIdentifier("view1", Some("db3"))))) + SubqueryAlias("view1", view)) } test("table exists") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala index 8952c72fe42fe..59d2dc46f00ce 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/CollapseRepartitionSuite.scala @@ -32,47 +32,168 @@ class CollapseRepartitionSuite extends PlanTest { val testRelation = LocalRelation('a.int, 'b.int) + + test("collapse two adjacent coalesces into one") { + // Always respects the top coalesces amd removes useless coalesce below coalesce + val query1 = testRelation + .coalesce(10) + .coalesce(20) + val query2 = testRelation + .coalesce(30) + .coalesce(20) + + val optimized1 = Optimize.execute(query1.analyze) + val optimized2 = Optimize.execute(query2.analyze) + val correctAnswer = testRelation.coalesce(20).analyze + + comparePlans(optimized1, correctAnswer) + comparePlans(optimized2, correctAnswer) + } + test("collapse two adjacent repartitions into one") { - val query = testRelation + // Always respects the top repartition amd removes useless repartition below repartition + val query1 = testRelation + .repartition(10) + .repartition(20) + val query2 = testRelation + .repartition(30) + .repartition(20) + + val optimized1 = Optimize.execute(query1.analyze) + val optimized2 = Optimize.execute(query2.analyze) + val correctAnswer = testRelation.repartition(20).analyze + + comparePlans(optimized1, correctAnswer) + comparePlans(optimized2, correctAnswer) + } + + test("coalesce above repartition") { + // Remove useless coalesce above repartition + val query1 = testRelation .repartition(10) + .coalesce(20) + + val optimized1 = Optimize.execute(query1.analyze) + val correctAnswer1 = testRelation.repartition(10).analyze + + comparePlans(optimized1, correctAnswer1) + + // No change in this case + val query2 = testRelation + .repartition(30) + .coalesce(20) + + val optimized2 = Optimize.execute(query2.analyze) + val correctAnswer2 = query2.analyze + + comparePlans(optimized2, correctAnswer2) + } + + test("repartition above coalesce") { + // Always respects the top repartition amd removes useless coalesce below repartition + val query1 = testRelation + .coalesce(10) + .repartition(20) + val query2 = testRelation + .coalesce(30) .repartition(20) - val optimized = Optimize.execute(query.analyze) + val optimized1 = Optimize.execute(query1.analyze) + val optimized2 = Optimize.execute(query2.analyze) val correctAnswer = testRelation.repartition(20).analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized1, correctAnswer) + comparePlans(optimized2, correctAnswer) } - test("collapse repartition and repartitionBy into one") { - val query = testRelation + test("repartitionBy above repartition") { + // Always respects the top repartitionBy amd removes useless repartition + val query1 = testRelation .repartition(10) .distribute('a)(20) + val query2 = testRelation + .repartition(30) + .distribute('a)(20) - val optimized = Optimize.execute(query.analyze) + val optimized1 = Optimize.execute(query1.analyze) + val optimized2 = Optimize.execute(query2.analyze) val correctAnswer = 
testRelation.distribute('a)(20).analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized1, correctAnswer) + comparePlans(optimized2, correctAnswer) } - test("collapse repartitionBy and repartition into one") { - val query = testRelation + test("repartitionBy above coalesce") { + // Always respects the top repartitionBy amd removes useless coalesce below repartition + val query1 = testRelation + .coalesce(10) + .distribute('a)(20) + val query2 = testRelation + .coalesce(30) .distribute('a)(20) - .repartition(10) - val optimized = Optimize.execute(query.analyze) - val correctAnswer = testRelation.distribute('a)(10).analyze + val optimized1 = Optimize.execute(query1.analyze) + val optimized2 = Optimize.execute(query2.analyze) + val correctAnswer = testRelation.distribute('a)(20).analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized1, correctAnswer) + comparePlans(optimized2, correctAnswer) + } + + test("repartition above repartitionBy") { + // Always respects the top repartition amd removes useless distribute below repartition + val query1 = testRelation + .distribute('a)(10) + .repartition(20) + val query2 = testRelation + .distribute('a)(30) + .repartition(20) + + val optimized1 = Optimize.execute(query1.analyze) + val optimized2 = Optimize.execute(query2.analyze) + val correctAnswer = testRelation.repartition(20).analyze + + comparePlans(optimized1, correctAnswer) + comparePlans(optimized2, correctAnswer) + + } + + test("coalesce above repartitionBy") { + // Remove useless coalesce above repartition + val query1 = testRelation + .distribute('a)(10) + .coalesce(20) + + val optimized1 = Optimize.execute(query1.analyze) + val correctAnswer1 = testRelation.distribute('a)(10).analyze + + comparePlans(optimized1, correctAnswer1) + + // No change in this case + val query2 = testRelation + .distribute('a)(30) + .coalesce(20) + + val optimized2 = Optimize.execute(query2.analyze) + val correctAnswer2 = query2.analyze + + comparePlans(optimized2, correctAnswer2) } test("collapse two adjacent repartitionBys into one") { - val query = testRelation + // Always respects the top repartitionBy + val query1 = testRelation .distribute('b)(10) .distribute('a)(20) + val query2 = testRelation + .distribute('b)(30) + .distribute('a)(20) - val optimized = Optimize.execute(query.analyze) + val optimized1 = Optimize.execute(query1.analyze) + val optimized2 = Optimize.execute(query2.analyze) val correctAnswer = testRelation.distribute('a)(20).analyze - comparePlans(optimized, correctAnswer) + comparePlans(optimized1, correctAnswer) + comparePlans(optimized2, correctAnswer) } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala index 5bd1bc80c3b8a..589607e3ad5cb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala @@ -320,16 +320,16 @@ class ColumnPruningSuite extends PlanTest { val query = Project(Seq($"x.key", $"y.key"), Join( - SubqueryAlias("x", input, None), - BroadcastHint(SubqueryAlias("y", input, None)), Inner, None)).analyze + SubqueryAlias("x", input), + BroadcastHint(SubqueryAlias("y", input)), Inner, None)).analyze val optimized = Optimize.execute(query) val expected = Join( - Project(Seq($"x.key"), SubqueryAlias("x", input, None)), + Project(Seq($"x.key"), 
SubqueryAlias("x", input)), BroadcastHint( - Project(Seq($"y.key"), SubqueryAlias("y", input, None))), + Project(Seq($"y.key"), SubqueryAlias("y", input))), Inner, None).analyze comparePlans(optimized, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala index a8aeedbd62759..9b6d68aee803a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/EliminateSubqueryAliasesSuite.scala @@ -46,13 +46,13 @@ class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper { test("eliminate top level subquery") { val input = LocalRelation('a.int, 'b.int) - val query = SubqueryAlias("a", input, None) + val query = SubqueryAlias("a", input) comparePlans(afterOptimization(query), input) } test("eliminate mid-tree subquery") { val input = LocalRelation('a.int, 'b.int) - val query = Filter(TrueLiteral, SubqueryAlias("a", input, None)) + val query = Filter(TrueLiteral, SubqueryAlias("a", input)) comparePlans( afterOptimization(query), Filter(TrueLiteral, LocalRelation('a.int, 'b.int))) @@ -61,7 +61,7 @@ class EliminateSubqueryAliasesSuite extends PlanTest with PredicateHelper { test("eliminate multiple subqueries") { val input = LocalRelation('a.int, 'b.int) val query = Filter(TrueLiteral, - SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input, None), None), None)) + SubqueryAlias("c", SubqueryAlias("b", SubqueryAlias("a", input)))) comparePlans( afterOptimization(query), Filter(TrueLiteral, LocalRelation('a.int, 'b.int))) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala index 65dd6225cea07..985e49069da90 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala @@ -129,15 +129,15 @@ class JoinOptimizationSuite extends PlanTest { val query = Project(Seq($"x.key", $"y.key"), Join( - SubqueryAlias("x", input, None), - BroadcastHint(SubqueryAlias("y", input, None)), Cross, None)).analyze + SubqueryAlias("x", input), + BroadcastHint(SubqueryAlias("y", input)), Cross, None)).analyze val optimized = Optimize.execute(query) val expected = Join( - Project(Seq($"x.key"), SubqueryAlias("x", input, None)), - BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input, None))), + Project(Seq($"x.key"), SubqueryAlias("x", input)), + BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input))), Cross, None).analyze comparePlans(optimized, expected) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala new file mode 100644 index 0000000000000..1b2f7a66b6a0b --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinReorderSuite.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.optimizer + +import org.apache.spark.sql.catalyst.SimpleCatalystConf +import org.apache.spark.sql.catalyst.dsl.expressions._ +import org.apache.spark.sql.catalyst.dsl.plans._ +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap} +import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Join, LogicalPlan} +import org.apache.spark.sql.catalyst.rules.RuleExecutor +import org.apache.spark.sql.catalyst.statsEstimation.{StatsEstimationTestBase, StatsTestPlan} +import org.apache.spark.sql.catalyst.util._ + + +class JoinReorderSuite extends PlanTest with StatsEstimationTestBase { + + override val conf = SimpleCatalystConf( + caseSensitiveAnalysis = true, cboEnabled = true, joinReorderEnabled = true) + + object Optimize extends RuleExecutor[LogicalPlan] { + val batches = + Batch("Operator Optimizations", FixedPoint(100), + CombineFilters, + PushDownPredicate, + PushPredicateThroughJoin, + ColumnPruning, + CollapseProject) :: + Batch("Join Reorder", Once, + CostBasedJoinReorder(conf)) :: Nil + } + + /** Set up tables and columns for testing */ + private val columnInfo: AttributeMap[ColumnStat] = AttributeMap(Seq( + attr("t1.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t1.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t2.k-1-5") -> ColumnStat(distinctCount = 5, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t3.v-1-100") -> ColumnStat(distinctCount = 100, min = Some(1), max = Some(100), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t4.k-1-2") -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4), + attr("t4.v-1-10") -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4) + )) + + private val nameToAttr: Map[String, Attribute] = columnInfo.map(kv => kv._1.name -> kv._1) + private val nameToColInfo: Map[String, (Attribute, ColumnStat)] = + columnInfo.map(kv => kv._1.name -> kv) + + // Table t1/t4: big table with two columns + private val t1 = StatsTestPlan( + outputList = Seq("t1.k-1-2", "t1.v-1-10").map(nameToAttr), + rowCount = 1000, + // size = rows * (overhead + column length) + size = Some(1000 * (8 + 4 + 4)), + attributeStats = AttributeMap(Seq("t1.k-1-2", "t1.v-1-10").map(nameToColInfo))) + + private val t4 = StatsTestPlan( + outputList = Seq("t4.k-1-2", "t4.v-1-10").map(nameToAttr), + rowCount = 2000, + size = Some(2000 * (8 + 4 + 4)), + attributeStats = AttributeMap(Seq("t4.k-1-2", "t4.v-1-10").map(nameToColInfo))) + + // Table t2/t3: small table with only one column + private val t2 = StatsTestPlan( + outputList = Seq("t2.k-1-5").map(nameToAttr), + rowCount = 20, + size = Some(20 * (8 + 4)), + 
attributeStats = AttributeMap(Seq("t2.k-1-5").map(nameToColInfo))) + + private val t3 = StatsTestPlan( + outputList = Seq("t3.v-1-100").map(nameToAttr), + rowCount = 100, + size = Some(100 * (8 + 4)), + attributeStats = AttributeMap(Seq("t3.v-1-100").map(nameToColInfo))) + + test("reorder 3 tables") { + val originalPlan = + t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + + // The cost of original plan (use only cardinality to simplify explanation): + // cost = cost(t1 J t2) = 1000 * 20 / 5 = 4000 + // In contrast, the cost of the best plan: + // cost = cost(t1 J t3) = 1000 * 100 / 100 = 1000 < 4000 + // so (t1 J t3) J t2 is better (has lower cost, i.e. intermediate result size) than + // the original order (t1 J t2) J t3. + val bestPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + + assertEqualPlans(originalPlan, bestPlan) + } + + test("reorder 3 tables - put cross join at the end") { + val originalPlan = + t1.join(t2).join(t3).where(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100")) + + val bestPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .join(t2, Inner, None) + + assertEqualPlans(originalPlan, bestPlan) + } + + test("reorder 3 tables with pure-attribute project") { + val originalPlan = + t1.join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.v-1-10")) + + val bestPlan = + t1.join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select(nameToAttr("t1.k-1-2"), nameToAttr("t1.v-1-10")) + .join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select(nameToAttr("t1.v-1-10")) + + assertEqualPlans(originalPlan, bestPlan) + } + + test("don't reorder if project contains non-attribute") { + val originalPlan = + t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .select((nameToAttr("t1.k-1-2") + nameToAttr("t2.k-1-5")) as "key", nameToAttr("t1.v-1-10")) + .join(t3, Inner, Some(nameToAttr("t1.v-1-10") === nameToAttr("t3.v-1-100"))) + .select("key".attr) + + assertEqualPlans(originalPlan, originalPlan) + } + + test("reorder 4 tables (bushy tree)") { + val originalPlan = + t1.join(t4).join(t2).join(t3).where((nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2")) && + (nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5")) && + (nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))) + + // The cost of original plan (use only cardinality to simplify explanation): + // cost(t1 J t4) = 1000 * 2000 / 2 = 1000000, cost(t1t4 J t2) = 1000000 * 20 / 5 = 4000000, + // cost = cost(t1 J t4) + cost(t1t4 J t2) = 5000000 + // In contrast, the cost of the best plan (a bushy tree): + // cost(t1 J t2) = 1000 * 20 / 5 = 4000, cost(t4 J t3) = 2000 * 100 / 100 = 2000, + // cost = cost(t1 J t2) + cost(t4 J t3) = 6000 << 5000000. 
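The arithmetic in the comments above and below can be double-checked mechanically. This is only a sanity check of the numbers, using the informal equi-join cardinality estimate (left rows times right rows divided by the join key's ndv) that the comments rely on; it is not the optimizer's estimation code.

```scala
object ReorderCostCheck {
  def main(args: Array[String]): Unit = {
    def joinRows(left: BigInt, right: BigInt, keyNdv: BigInt): BigInt = left * right / keyNdv

    val (t1, t2, t3, t4) = (BigInt(1000), BigInt(20), BigInt(100), BigInt(2000))

    // Original order for the 4-table case: (t1 J t4) J t2, then t3.
    val t1t4   = joinRows(t1, t4, keyNdv = 2)      // 1,000,000
    val t1t4t2 = joinRows(t1t4, t2, keyNdv = 5)    // 4,000,000
    println(s"left-deep intermediate rows = ${t1t4 + t1t4t2}")  // 5,000,000

    // Bushy best plan: (t1 J t2) J (t4 J t3).
    val t1t2 = joinRows(t1, t2, keyNdv = 5)        // 4,000
    val t4t3 = joinRows(t4, t3, keyNdv = 100)      // 2,000
    println(s"bushy intermediate rows     = ${t1t2 + t4t3}")    // 6,000
  }
}
```

The real comparison also weighs byte sizes against row counts (the 0.7 weight on rows in `Cost.lessThan`), but cardinalities alone already separate the two orders by roughly three orders of magnitude here.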
+ val bestPlan = + t1.join(t2, Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t2.k-1-5"))) + .join(t4.join(t3, Inner, Some(nameToAttr("t4.v-1-10") === nameToAttr("t3.v-1-100"))), + Inner, Some(nameToAttr("t1.k-1-2") === nameToAttr("t4.k-1-2"))) + + assertEqualPlans(originalPlan, bestPlan) + } + + private def assertEqualPlans( + originalPlan: LogicalPlan, + groundTruthBestPlan: LogicalPlan): Unit = { + val optimized = Optimize.execute(originalPlan.analyze) + val normalized1 = normalizePlan(normalizeExprIds(optimized)) + val normalized2 = normalizePlan(normalizeExprIds(groundTruthBestPlan.analyze)) + if (!sameJoinPlan(normalized1, normalized2)) { + fail( + s""" + |== FAIL: Plans do not match === + |${sideBySide(normalized1.treeString, normalized2.treeString).mkString("\n")} + """.stripMargin) + } + } + + /** Consider symmetry for joins when comparing plans. */ + private def sameJoinPlan(plan1: LogicalPlan, plan2: LogicalPlan): Boolean = { + (plan1, plan2) match { + case (j1: Join, j2: Join) => + (sameJoinPlan(j1.left, j2.left) && sameJoinPlan(j1.right, j2.right)) || + (sameJoinPlan(j1.left, j2.right) && sameJoinPlan(j1.right, j2.left)) + case _ => + plan1 == plan2 + } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala index 67d5d2202b680..411777d6e85a2 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/parser/PlanParserSuite.scala @@ -79,7 +79,7 @@ class PlanParserSuite extends PlanTest { def cte(plan: LogicalPlan, namedPlans: (String, LogicalPlan)*): With = { val ctes = namedPlans.map { case (name, cte) => - name -> SubqueryAlias(name, cte, None) + name -> SubqueryAlias(name, cte) } With(plan, ctes) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala index 3b7e5e938a8e4..e9b7a0c6ad671 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala @@ -62,7 +62,7 @@ abstract class PlanTest extends SparkFunSuite with PredicateHelper { * - Sample the seed will replaced by 0L. * - Join conditions will be resorted by hashCode. */ - private def normalizePlan(plan: LogicalPlan): LogicalPlan = { + protected def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan transform { case filter @ Filter(condition: Expression, child: LogicalPlan) => Filter(splitConjunctivePredicates(condition).map(rewriteEqual(_)).sortBy(_.hashCode()) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala new file mode 100644 index 0000000000000..e5dc811c8b7db --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/BasicStatsEstimationSuite.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.statsEstimation + +import org.apache.spark.sql.catalyst.CatalystConf +import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference, Literal} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.types.IntegerType + + +class BasicStatsEstimationSuite extends StatsEstimationTestBase { + val attribute = attr("key") + val colStat = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4) + + val plan = StatsTestPlan( + outputList = Seq(attribute), + attributeStats = AttributeMap(Seq(attribute -> colStat)), + rowCount = 10, + // row count * (overhead + column size) + size = Some(10 * (8 + 4))) + + test("limit estimation: limit < child's rowCount") { + val localLimit = LocalLimit(Literal(2), plan) + val globalLimit = GlobalLimit(Literal(2), plan) + // LocalLimit's stats is just its child's stats except column stats + checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + checkStats(globalLimit, Statistics(sizeInBytes = 24, rowCount = Some(2))) + } + + test("limit estimation: limit > child's rowCount") { + val localLimit = LocalLimit(Literal(20), plan) + val globalLimit = GlobalLimit(Literal(20), plan) + checkStats(localLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + // Limit is larger than child's rowCount, so GlobalLimit's stats is equal to its child's stats. 
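The Sample estimate exercised just below follows directly from the new `computeStats`: both size and row count are scaled by the sampling ratio and rounded up. A small check of the expected numbers (the `ceil` helper stands in for `EstimationUtils.ceil` and is illustrative only):

```scala
object SampleStatsCheck {
  def main(args: Array[String]): Unit = {
    def ceil(x: BigDecimal): BigInt = x.setScale(0, BigDecimal.RoundingMode.CEILING).toBigInt

    val ratio = BigDecimal(0.5) - BigDecimal(0.0)   // upperBound - lowerBound
    val childSize = BigInt(120)                     // 10 rows * (8 + 4) bytes
    val childRows = BigInt(10)

    println(ceil(BigDecimal(childSize) * ratio))    // 60 bytes
    println(ceil(BigDecimal(childRows) * ratio))    // 5 rows

    // When the child has no row count, only the size is scaled: ceil(120 * 0.11) = 14.
    println(ceil(BigDecimal(120) * BigDecimal(0.11)))
  }
}
```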
+ checkStats(globalLimit, plan.stats(conf).copy(attributeStats = AttributeMap(Nil))) + } + + test("limit estimation: limit = 0") { + val localLimit = LocalLimit(Literal(0), plan) + val globalLimit = GlobalLimit(Literal(0), plan) + val stats = Statistics(sizeInBytes = 1, rowCount = Some(0)) + checkStats(localLimit, stats) + checkStats(globalLimit, stats) + } + + test("sample estimation") { + val sample = Sample(0.0, 0.5, withReplacement = false, (math.random * 1000).toLong, plan)() + checkStats(sample, Statistics(sizeInBytes = 60, rowCount = Some(5))) + + // Child doesn't have rowCount in stats + val childStats = Statistics(sizeInBytes = 120) + val childPlan = DummyLogicalPlan(childStats, childStats) + val sample2 = + Sample(0.0, 0.11, withReplacement = false, (math.random * 1000).toLong, childPlan)() + checkStats(sample2, Statistics(sizeInBytes = 14)) + } + + test("estimate statistics when the conf changes") { + val expectedDefaultStats = + Statistics( + sizeInBytes = 40, + rowCount = Some(10), + attributeStats = AttributeMap(Seq( + AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))), + isBroadcastable = false) + val expectedCboStats = + Statistics( + sizeInBytes = 4, + rowCount = Some(1), + attributeStats = AttributeMap(Seq( + AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))), + isBroadcastable = false) + + val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) + checkStats( + plan, expectedStatsCboOn = expectedCboStats, expectedStatsCboOff = expectedDefaultStats) + } + + /** Check estimated stats when cbo is turned on/off. */ + private def checkStats( + plan: LogicalPlan, + expectedStatsCboOn: Statistics, + expectedStatsCboOff: Statistics): Unit = { + assert(plan.stats(conf.copy(cboEnabled = true)) == expectedStatsCboOn) + // Invalidate statistics + plan.invalidateStatsCache() + assert(plan.stats(conf.copy(cboEnabled = false)) == expectedStatsCboOff) + } + + /** Check estimated stats when it's the same whether cbo is turned on or off. */ + private def checkStats(plan: LogicalPlan, expectedStats: Statistics): Unit = + checkStats(plan, expectedStats, expectedStats) +} + +/** + * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes + * a simple statistics or a cbo estimated statistics based on the conf. 
+ */ +private case class DummyLogicalPlan( + defaultStats: Statistics, + cboStats: Statistics) extends LogicalPlan { + override def output: Seq[Attribute] = Nil + override def children: Seq[LogicalPlan] = Nil + override def computeStats(conf: CatalystConf): Statistics = + if (conf.cboEnabled) cboStats else defaultStats +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala index 8be74ced7bb71..4691913c8c986 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/FilterEstimationSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.statsEstimation import java.sql.Date import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, Filter, Statistics} import org.apache.spark.sql.catalyst.plans.logical.statsEstimation.EstimationUtils._ import org.apache.spark.sql.types._ @@ -33,219 +33,235 @@ class FilterEstimationSuite extends StatsEstimationTestBase { // Suppose our test table has 10 rows and 6 columns. // First column cint has values: 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 // Hence, distinctCount:10, min:1, max:10, nullCount:0, avgLen:4, maxLen:4 - val arInt = AttributeReference("cint", IntegerType)() - val childColStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + val attrInt = AttributeReference("cint", IntegerType)() + val colStatInt = ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), nullCount = 0, avgLen = 4, maxLen = 4) // only 2 values - val arBool = AttributeReference("cbool", BooleanType)() - val childColStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), + val attrBool = AttributeReference("cbool", BooleanType)() + val colStatBool = ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), nullCount = 0, avgLen = 1, maxLen = 1) // Second column cdate has 10 values from 2017-01-01 through 2017-01-10. val dMin = Date.valueOf("2017-01-01") val dMax = Date.valueOf("2017-01-10") - val arDate = AttributeReference("cdate", DateType)() - val childColStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), + val attrDate = AttributeReference("cdate", DateType)() + val colStatDate = ColumnStat(distinctCount = 10, min = Some(dMin), max = Some(dMax), nullCount = 0, avgLen = 4, maxLen = 4) // Fourth column cdecimal has 4 values from 0.20 through 0.80 at increment of 0.20. 
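A brief aside on the size arithmetic the limit and sample expectations above rely on — a minimal sketch of the "row count * (overhead + column size)" rule quoted in the StatsTestPlan comment; the object and helper names below are made up for illustration:

object SizeEstimationSketch {
  // Mirrors the "row count * (overhead + column size)" comment from the test fixture above,
  // assuming the 8-byte per-row overhead used there.
  def estimatedSizeInBytes(rowCount: Long, columnSizes: Seq[Int]): Long =
    rowCount * (8 + columnSizes.sum)

  def main(args: Array[String]): Unit = {
    assert(estimatedSizeInBytes(2, Seq(4)) == 24) // GlobalLimit(Literal(2)) over the 4-byte int column
    assert(estimatedSizeInBytes(5, Seq(4)) == 60) // Sample(0.0, 0.5, ...) of the 10-row child plan
  }
}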
val decMin = new java.math.BigDecimal("0.200000000000000000") val decMax = new java.math.BigDecimal("0.800000000000000000") - val arDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() - val childColStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), + val attrDecimal = AttributeReference("cdecimal", DecimalType(18, 18))() + val colStatDecimal = ColumnStat(distinctCount = 4, min = Some(decMin), max = Some(decMax), nullCount = 0, avgLen = 8, maxLen = 8) // Fifth column cdouble has 10 double values: 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0 - val arDouble = AttributeReference("cdouble", DoubleType)() - val childColStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), + val attrDouble = AttributeReference("cdouble", DoubleType)() + val colStatDouble = ColumnStat(distinctCount = 10, min = Some(1.0), max = Some(10.0), nullCount = 0, avgLen = 8, maxLen = 8) // Sixth column cstring has 10 String values: // "A0", "A1", "A2", "A3", "A4", "A5", "A6", "A7", "A8", "A9" - val arString = AttributeReference("cstring", StringType)() - val childColStatString = ColumnStat(distinctCount = 10, min = None, max = None, + val attrString = AttributeReference("cstring", StringType)() + val colStatString = ColumnStat(distinctCount = 10, min = None, max = None, nullCount = 0, avgLen = 2, maxLen = 2) + val attributeMap = AttributeMap(Seq( + attrInt -> colStatInt, + attrBool -> colStatBool, + attrDate -> colStatDate, + attrDecimal -> colStatDecimal, + attrDouble -> colStatDouble, + attrString -> colStatString)) + test("cint = 2") { validateEstimatedStats( - arInt, - Filter(EqualTo(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - 1) + Filter(EqualTo(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 1) } test("cint <=> 2") { validateEstimatedStats( - arInt, - Filter(EqualNullSafe(arInt, Literal(2)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), - nullCount = 0, avgLen = 4, maxLen = 4), - 1) + Filter(EqualNullSafe(attrInt, Literal(2)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 1, min = Some(2), max = Some(2), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 1) } test("cint = 0") { // This is an out-of-range case since 0 is outside the range [min, max] validateEstimatedStats( - arInt, - Filter(EqualTo(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(EqualTo(attrInt, Literal(0)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint < 3") { validateEstimatedStats( - arInt, - Filter(LessThan(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(LessThan(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cint < 0") { // This is a corner case since literal 0 is smaller than min. 
validateEstimatedStats( - arInt, - Filter(LessThan(arInt, Literal(0)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(LessThan(attrInt, Literal(0)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint <= 3") { validateEstimatedStats( - arInt, - Filter(LessThanOrEqual(arInt, Literal(3)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(LessThanOrEqual(attrInt, Literal(3)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(3), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cint > 6") { validateEstimatedStats( - arInt, - Filter(GreaterThan(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 5) + Filter(GreaterThan(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 5) } test("cint > 10") { // This is a corner case since max value is 10. validateEstimatedStats( - arInt, - Filter(GreaterThan(arInt, Literal(10)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(GreaterThan(attrInt, Literal(10)), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint >= 6") { validateEstimatedStats( - arInt, - Filter(GreaterThanOrEqual(arInt, Literal(6)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 5) + Filter(GreaterThanOrEqual(attrInt, Literal(6)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 4, min = Some(6), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 5) } test("cint IS NULL") { validateEstimatedStats( - arInt, - Filter(IsNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 0, min = None, max = None, - nullCount = 0, avgLen = 4, maxLen = 4), - 0) + Filter(IsNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)), + Nil, + expectedRowCount = 0) } test("cint IS NOT NULL") { validateEstimatedStats( - arInt, - Filter(IsNotNull(arInt), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 10) + Filter(IsNotNull(attrInt), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 10) } test("cint > 3 AND cint <= 6") { - val condition = And(GreaterThan(arInt, Literal(3)), LessThanOrEqual(arInt, Literal(6))) + val condition = And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6))) validateEstimatedStats( - arInt, - Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), - nullCount = 0, avgLen = 4, maxLen = 4), - 4) + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(6), + nullCount = 0, avgLen = 4, maxLen = 4)), + 
expectedRowCount = 4) } test("cint = 3 OR cint = 6") { - val condition = Or(EqualTo(arInt, Literal(3)), EqualTo(arInt, Literal(6))) + val condition = Or(EqualTo(attrInt, Literal(3)), EqualTo(attrInt, Literal(6))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 2) + } + + test("Not(cint > 3 AND cint <= 6)") { + val condition = Not(And(GreaterThan(attrInt, Literal(3)), LessThanOrEqual(attrInt, Literal(6)))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 6) + } + + test("Not(cint <= 3 OR cint > 6)") { + val condition = Not(Or(LessThanOrEqual(attrInt, Literal(3)), GreaterThan(attrInt, Literal(6)))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> colStatInt), + expectedRowCount = 5) + } + + test("Not(cint = 3 AND cstring < 'A8')") { + val condition = Not(And(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) + validateEstimatedStats( + Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), + Seq(attrInt -> colStatInt, attrString -> colStatString), + expectedRowCount = 10) + } + + test("Not(cint = 3 OR cstring < 'A8')") { + val condition = Not(Or(EqualTo(attrInt, Literal(3)), LessThan(attrString, Literal("A8")))) validateEstimatedStats( - arInt, - Filter(condition, childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 2) + Filter(condition, childStatsTestPlan(Seq(attrInt, attrString), 10L)), + Seq(attrInt -> colStatInt, attrString -> colStatString), + expectedRowCount = 9) } test("cint IN (3, 4, 5)") { validateEstimatedStats( - arInt, - Filter(InSet(arInt, Set(3, 4, 5)), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(InSet(attrInt, Set(3, 4, 5)), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 3, min = Some(3), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cint NOT IN (3, 4, 5)") { validateEstimatedStats( - arInt, - Filter(Not(InSet(arInt, Set(3, 4, 5))), childStatsTestPlan(Seq(arInt), 10L)), - ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), - nullCount = 0, avgLen = 4, maxLen = 4), - 7) + Filter(Not(InSet(attrInt, Set(3, 4, 5))), childStatsTestPlan(Seq(attrInt), 10L)), + Seq(attrInt -> ColumnStat(distinctCount = 10, min = Some(1), max = Some(10), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 7) } test("cbool = true") { validateEstimatedStats( - arBool, - Filter(EqualTo(arBool, Literal(true)), childStatsTestPlan(Seq(arBool), 10L)), - ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1), - 5) + Filter(EqualTo(attrBool, Literal(true)), childStatsTestPlan(Seq(attrBool), 10L)), + Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1)), + expectedRowCount = 5) } test("cbool > false") { - // bool comparison is not supported yet, so stats remain same. 
validateEstimatedStats( - arBool, - Filter(GreaterThan(arBool, Literal(false)), childStatsTestPlan(Seq(arBool), 10L)), - ColumnStat(distinctCount = 2, min = Some(false), max = Some(true), - nullCount = 0, avgLen = 1, maxLen = 1), - 10) + Filter(GreaterThan(attrBool, Literal(false)), childStatsTestPlan(Seq(attrBool), 10L)), + Seq(attrBool -> ColumnStat(distinctCount = 1, min = Some(true), max = Some(true), + nullCount = 0, avgLen = 1, maxLen = 1)), + expectedRowCount = 5) } test("cdate = cast('2017-01-02' AS DATE)") { val d20170102 = Date.valueOf("2017-01-02") validateEstimatedStats( - arDate, - Filter(EqualTo(arDate, Literal(d20170102)), - childStatsTestPlan(Seq(arDate), 10L)), - ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), - nullCount = 0, avgLen = 4, maxLen = 4), - 1) + Filter(EqualTo(attrDate, Literal(d20170102)), + childStatsTestPlan(Seq(attrDate), 10L)), + Seq(attrDate -> ColumnStat(distinctCount = 1, min = Some(d20170102), max = Some(d20170102), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 1) } test("cdate < cast('2017-01-03' AS DATE)") { val d20170103 = Date.valueOf("2017-01-03") validateEstimatedStats( - arDate, - Filter(LessThan(arDate, Literal(d20170103)), - childStatsTestPlan(Seq(arDate), 10L)), - ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(LessThan(attrDate, Literal(d20170103)), + childStatsTestPlan(Seq(attrDate), 10L)), + Seq(attrDate -> ColumnStat(distinctCount = 2, min = Some(dMin), max = Some(d20170103), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("""cdate IN ( cast('2017-01-03' AS DATE), @@ -254,133 +270,118 @@ class FilterEstimationSuite extends StatsEstimationTestBase { val d20170104 = Date.valueOf("2017-01-04") val d20170105 = Date.valueOf("2017-01-05") validateEstimatedStats( - arDate, - Filter(In(arDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))), - childStatsTestPlan(Seq(arDate), 10L)), - ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), - nullCount = 0, avgLen = 4, maxLen = 4), - 3) + Filter(In(attrDate, Seq(Literal(d20170103), Literal(d20170104), Literal(d20170105))), + childStatsTestPlan(Seq(attrDate), 10L)), + Seq(attrDate -> ColumnStat(distinctCount = 3, min = Some(d20170103), max = Some(d20170105), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 3) } test("cdecimal = 0.400000000000000000") { val dec_0_40 = new java.math.BigDecimal("0.400000000000000000") validateEstimatedStats( - arDecimal, - Filter(EqualTo(arDecimal, Literal(dec_0_40)), - childStatsTestPlan(Seq(arDecimal), 4L)), - ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), - nullCount = 0, avgLen = 8, maxLen = 8), - 1) + Filter(EqualTo(attrDecimal, Literal(dec_0_40)), + childStatsTestPlan(Seq(attrDecimal), 4L)), + Seq(attrDecimal -> ColumnStat(distinctCount = 1, min = Some(dec_0_40), max = Some(dec_0_40), + nullCount = 0, avgLen = 8, maxLen = 8)), + expectedRowCount = 1) } test("cdecimal < 0.60 ") { val dec_0_60 = new java.math.BigDecimal("0.600000000000000000") validateEstimatedStats( - arDecimal, - Filter(LessThan(arDecimal, Literal(dec_0_60)), - childStatsTestPlan(Seq(arDecimal), 4L)), - ColumnStat(distinctCount = 3, min = Some(decMin), max = Some(dec_0_60), - nullCount = 0, avgLen = 8, maxLen = 8), - 3) + Filter(LessThan(attrDecimal, Literal(dec_0_60)), + childStatsTestPlan(Seq(attrDecimal), 4L)), + Seq(attrDecimal -> ColumnStat(distinctCount = 3, min = 
Some(decMin), max = Some(dec_0_60), + nullCount = 0, avgLen = 8, maxLen = 8)), + expectedRowCount = 3) } test("cdouble < 3.0") { validateEstimatedStats( - arDouble, - Filter(LessThan(arDouble, Literal(3.0)), childStatsTestPlan(Seq(arDouble), 10L)), - ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), - nullCount = 0, avgLen = 8, maxLen = 8), - 3) + Filter(LessThan(attrDouble, Literal(3.0)), childStatsTestPlan(Seq(attrDouble), 10L)), + Seq(attrDouble -> ColumnStat(distinctCount = 2, min = Some(1.0), max = Some(3.0), + nullCount = 0, avgLen = 8, maxLen = 8)), + expectedRowCount = 3) } test("cstring = 'A2'") { validateEstimatedStats( - arString, - Filter(EqualTo(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), - ColumnStat(distinctCount = 1, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2), - 1) + Filter(EqualTo(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), + Seq(attrString -> ColumnStat(distinctCount = 1, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2)), + expectedRowCount = 1) } - // There is no min/max statistics for String type. We estimate 10 rows returned. - test("cstring < 'A2'") { + test("cstring < 'A2' - unsupported condition") { validateEstimatedStats( - arString, - Filter(LessThan(arString, Literal("A2")), childStatsTestPlan(Seq(arString), 10L)), - ColumnStat(distinctCount = 10, min = None, max = None, - nullCount = 0, avgLen = 2, maxLen = 2), - 10) + Filter(LessThan(attrString, Literal("A2")), childStatsTestPlan(Seq(attrString), 10L)), + Seq(attrString -> ColumnStat(distinctCount = 10, min = None, max = None, + nullCount = 0, avgLen = 2, maxLen = 2)), + expectedRowCount = 10) } - // This is a corner test case. We want to test if we can handle the case when the number of - // valid values in IN clause is greater than the number of distinct values for a given column. - // For example, column has only 2 distinct values 1 and 6. - // The predicate is: column IN (1, 2, 3, 4, 5). test("cint IN (1, 2, 3, 4, 5)") { + // This is a corner test case. We want to test if we can handle the case when the number of + // valid values in IN clause is greater than the number of distinct values for a given column. + // For example, column has only 2 distinct values 1 and 6. + // The predicate is: column IN (1, 2, 3, 4, 5). 
val cornerChildColStatInt = ColumnStat(distinctCount = 2, min = Some(1), max = Some(6), nullCount = 0, avgLen = 4, maxLen = 4) val cornerChildStatsTestplan = StatsTestPlan( - outputList = Seq(arInt), + outputList = Seq(attrInt), rowCount = 2L, - attributeStats = AttributeMap(Seq(arInt -> cornerChildColStatInt)) + attributeStats = AttributeMap(Seq(attrInt -> cornerChildColStatInt)) ) validateEstimatedStats( - arInt, - Filter(InSet(arInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), - ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), - nullCount = 0, avgLen = 4, maxLen = 4), - 2) + Filter(InSet(attrInt, Set(1, 2, 3, 4, 5)), cornerChildStatsTestplan), + Seq(attrInt -> ColumnStat(distinctCount = 2, min = Some(1), max = Some(5), + nullCount = 0, avgLen = 4, maxLen = 4)), + expectedRowCount = 2) } private def childStatsTestPlan(outList: Seq[Attribute], tableRowCount: BigInt): StatsTestPlan = { StatsTestPlan( outputList = outList, rowCount = tableRowCount, - attributeStats = AttributeMap(Seq( - arInt -> childColStatInt, - arBool -> childColStatBool, - arDate -> childColStatDate, - arDecimal -> childColStatDecimal, - arDouble -> childColStatDouble, - arString -> childColStatString - )) - ) + attributeStats = AttributeMap(outList.map(a => a -> attributeMap(a)))) } private def validateEstimatedStats( - ar: AttributeReference, filterNode: Filter, - expectedColStats: ColumnStat, - rowCount: Int): Unit = { - - val expectedAttrStats = toAttributeMap(Seq(ar.name -> expectedColStats), filterNode) - val expectedSizeInBytes = getOutputSize(filterNode.output, rowCount, expectedAttrStats) - - val filteredStats = filterNode.stats(conf) - assert(filteredStats.sizeInBytes == expectedSizeInBytes) - assert(filteredStats.rowCount.get == rowCount) - assert(filteredStats.attributeStats(ar) == expectedColStats) - - // If the filter has a binary operator (including those nested inside - // AND/OR/NOT), swap the sides of the attribte and the literal, reverse the - // operator, and then check again. - val rewrittenFilter = filterNode transformExpressionsDown { - case EqualTo(ar: AttributeReference, l: Literal) => - EqualTo(l, ar) - - case LessThan(ar: AttributeReference, l: Literal) => - GreaterThan(l, ar) - case LessThanOrEqual(ar: AttributeReference, l: Literal) => - GreaterThanOrEqual(l, ar) - - case GreaterThan(ar: AttributeReference, l: Literal) => - LessThan(l, ar) - case GreaterThanOrEqual(ar: AttributeReference, l: Literal) => - LessThanOrEqual(l, ar) + expectedColStats: Seq[(Attribute, ColumnStat)], + expectedRowCount: Int): Unit = { + + // If the filter has a binary operator (including those nested inside AND/OR/NOT), swap the + // sides of the attribute and the literal, reverse the operator, and then check again. 
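Looping back to the "cint IN (1, 2, 3, 4, 5)" corner case just above — a hedged sketch of the capping arithmetic behind the expected numbers, not the estimator's actual code:

object InClauseCornerCaseSketch {
  def main(args: Array[String]): Unit = {
    val columnNdv = 2L    // the column only contains the two values 1 and 6
    val inListSize = 5L   // IN (1, 2, 3, 4, 5); all literals fall inside [min, max] = [1, 6]
    // The estimated distinct count after the filter is capped at the column's own ndv,
    // so the selectivity works out to 2 / 2 = 1.0 and both child rows are expected to pass.
    val cappedNdv = math.min(columnNdv, inListSize)
    val expectedRowCount = 2L * cappedNdv / columnNdv // child rowCount * selectivity
    assert(cappedNdv == 2 && expectedRowCount == 2)
  }
}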
+ val swappedFilter = filterNode transformExpressionsDown { + case EqualTo(attr: Attribute, l: Literal) => + EqualTo(l, attr) + + case LessThan(attr: Attribute, l: Literal) => + GreaterThan(l, attr) + case LessThanOrEqual(attr: Attribute, l: Literal) => + GreaterThanOrEqual(l, attr) + + case GreaterThan(attr: Attribute, l: Literal) => + LessThan(l, attr) + case GreaterThanOrEqual(attr: Attribute, l: Literal) => + LessThanOrEqual(l, attr) + } + + val testFilters = if (swappedFilter != filterNode) { + Seq(swappedFilter, filterNode) + } else { + Seq(filterNode) } - if (rewrittenFilter != filterNode) { - validateEstimatedStats(ar, rewrittenFilter, expectedColStats, rowCount) + testFilters.foreach { filter => + val expectedAttributeMap = AttributeMap(expectedColStats) + val expectedStats = Statistics( + sizeInBytes = getOutputSize(filter.output, expectedRowCount, expectedAttributeMap), + rowCount = Some(expectedRowCount), + attributeStats = expectedAttributeMap) + assert(filter.stats(conf) == expectedStats) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala deleted file mode 100644 index 212d57a9bcf95..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsConfSuite.scala +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.sql.catalyst.statsEstimation - -import org.apache.spark.sql.catalyst.CatalystConf -import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, AttributeReference} -import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LogicalPlan, Statistics} -import org.apache.spark.sql.types.IntegerType - - -class StatsConfSuite extends StatsEstimationTestBase { - test("estimate statistics when the conf changes") { - val expectedDefaultStats = - Statistics( - sizeInBytes = 40, - rowCount = Some(10), - attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(10, Some(1), Some(10), 0, 4, 4))), - isBroadcastable = false) - val expectedCboStats = - Statistics( - sizeInBytes = 4, - rowCount = Some(1), - attributeStats = AttributeMap(Seq( - AttributeReference("c1", IntegerType)() -> ColumnStat(1, Some(5), Some(5), 0, 4, 4))), - isBroadcastable = false) - - val plan = DummyLogicalPlan(defaultStats = expectedDefaultStats, cboStats = expectedCboStats) - // Return the statistics estimated by cbo - assert(plan.stats(conf.copy(cboEnabled = true)) == expectedCboStats) - // Invalidate statistics - plan.invalidateStatsCache() - // Return the simple statistics - assert(plan.stats(conf.copy(cboEnabled = false)) == expectedDefaultStats) - } -} - -/** - * This class is used for unit-testing the cbo switch, it mimics a logical plan which computes - * a simple statistics or a cbo estimated statistics based on the conf. - */ -private case class DummyLogicalPlan( - defaultStats: Statistics, - cboStats: Statistics) extends LogicalPlan { - override def output: Seq[Attribute] = Nil - override def children: Seq[LogicalPlan] = Nil - override def computeStats(conf: CatalystConf): Statistics = - if (conf.cboEnabled) cboStats else defaultStats -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala index c56b41ce37636..9b2b8dbe1bf4a 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/statsEstimation/StatsEstimationTestBase.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.plans.logical.{ColumnStat, LeafNode, Logica import org.apache.spark.sql.types.{IntegerType, StringType} -class StatsEstimationTestBase extends SparkFunSuite { +trait StatsEstimationTestBase extends SparkFunSuite { /** Enable stats estimation based on CBO. */ protected val conf = SimpleCatalystConf(caseSensitiveAnalysis = true, cboEnabled = true) @@ -48,7 +48,7 @@ class StatsEstimationTestBase extends SparkFunSuite { /** * This class is used for unit-testing. It's a logical plan whose output and stats are passed in. 
*/ -protected case class StatsTestPlan( +case class StatsTestPlan( outputList: Seq[Attribute], rowCount: BigInt, attributeStats: AttributeMap[ColumnStat], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index 1b04623596073..16edb35b1d43f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -1093,7 +1093,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def as(alias: String): Dataset[T] = withTypedPlan { - SubqueryAlias(alias, logicalPlan, None) + SubqueryAlias(alias, logicalPlan) } /** @@ -2441,11 +2441,11 @@ class Dataset[T] private[sql]( } /** - * Returns a new Dataset that has exactly `numPartitions` partitions. - * Similar to coalesce defined on an `RDD`, this operation results in a narrow dependency, e.g. - * if you go from 1000 partitions to 100 partitions, there will not be a shuffle, instead each of - * the 100 new partitions will claim 10 of the current partitions. If a larger number of - * partitions is requested, it will stay at the current number of partitions. + * Returns a new Dataset that has exactly `numPartitions` partitions, when the fewer partitions + * are requested. If a larger number of partitions is requested, it will stay at the current + * number of partitions. Similar to coalesce defined on an `RDD`, this operation results in + * a narrow dependency, e.g. if you go from 1000 partitions to 100 partitions, there will not + * be a shuffle, instead each of the 100 new partitions will claim 10 of the current partitions. * * However, if you're doing a drastic coalesce, e.g. to numPartitions = 1, * this may result in your computation taking place on fewer nodes than diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala index 80138510dc9ee..0ea806d6cb50b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CacheManager.scala @@ -19,6 +19,8 @@ package org.apache.spark.sql.execution import java.util.concurrent.locks.ReentrantReadWriteLock +import scala.collection.JavaConverters._ + import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.internal.Logging @@ -45,7 +47,7 @@ case class CachedData(plan: LogicalPlan, cachedRepresentation: InMemoryRelation) class CacheManager extends Logging { @transient - private val cachedData = new scala.collection.mutable.ArrayBuffer[CachedData] + private val cachedData = new java.util.LinkedList[CachedData] @transient private val cacheLock = new ReentrantReadWriteLock @@ -70,7 +72,7 @@ class CacheManager extends Logging { /** Clears all cached tables. 
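As an aside on the reworded coalesce() scaladoc in the Dataset.scala hunk above — a hedged, spark-shell-style sketch of the documented contract (the builder is shown only to keep the snippet self-contained):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("coalesce-sketch").getOrCreate()
val df = spark.range(100).repartition(10)
assert(df.coalesce(3).rdd.getNumPartitions == 3)   // fewer partitions requested: honored
assert(df.coalesce(13).rdd.getNumPartitions == 10) // more than the current count: stays at 10
spark.stop()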
*/ def clearCache(): Unit = writeLock { - cachedData.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) + cachedData.asScala.foreach(_.cachedRepresentation.cachedColumnBuffers.unpersist()) cachedData.clear() } @@ -88,46 +90,81 @@ class CacheManager extends Logging { query: Dataset[_], tableName: Option[String] = None, storageLevel: StorageLevel = MEMORY_AND_DISK): Unit = writeLock { - val planToCache = query.queryExecution.analyzed + val planToCache = query.logicalPlan if (lookupCachedData(planToCache).nonEmpty) { logWarning("Asked to cache already cached data.") } else { val sparkSession = query.sparkSession - cachedData += - CachedData( - planToCache, - InMemoryRelation( - sparkSession.sessionState.conf.useCompression, - sparkSession.sessionState.conf.columnBatchSize, - storageLevel, - sparkSession.sessionState.executePlan(planToCache).executedPlan, - tableName)) + cachedData.add(CachedData( + planToCache, + InMemoryRelation( + sparkSession.sessionState.conf.useCompression, + sparkSession.sessionState.conf.columnBatchSize, + storageLevel, + sparkSession.sessionState.executePlan(planToCache).executedPlan, + tableName))) } } /** - * Tries to remove the data for the given [[Dataset]] from the cache. - * No operation, if it's already uncached. + * Un-cache all the cache entries that refer to the given plan. + */ + def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Unit = writeLock { + uncacheQuery(query.sparkSession, query.logicalPlan, blocking) + } + + /** + * Un-cache all the cache entries that refer to the given plan. */ - def uncacheQuery(query: Dataset[_], blocking: Boolean = true): Boolean = writeLock { - val planToCache = query.queryExecution.analyzed - val dataIndex = cachedData.indexWhere(cd => planToCache.sameResult(cd.plan)) - val found = dataIndex >= 0 - if (found) { - cachedData(dataIndex).cachedRepresentation.cachedColumnBuffers.unpersist(blocking) - cachedData.remove(dataIndex) + def uncacheQuery(spark: SparkSession, plan: LogicalPlan, blocking: Boolean): Unit = writeLock { + val it = cachedData.iterator() + while (it.hasNext) { + val cd = it.next() + if (cd.plan.find(_.sameResult(plan)).isDefined) { + cd.cachedRepresentation.cachedColumnBuffers.unpersist(blocking) + it.remove() + } } - found + } + + /** + * Tries to re-cache all the cache entries that refer to the given plan. + */ + def recacheByPlan(spark: SparkSession, plan: LogicalPlan): Unit = writeLock { + recacheByCondition(spark, _.find(_.sameResult(plan)).isDefined) + } + + private def recacheByCondition(spark: SparkSession, condition: LogicalPlan => Boolean): Unit = { + val it = cachedData.iterator() + val needToRecache = scala.collection.mutable.ArrayBuffer.empty[CachedData] + while (it.hasNext) { + val cd = it.next() + if (condition(cd.plan)) { + cd.cachedRepresentation.cachedColumnBuffers.unpersist() + // Remove the cache entry before we create a new one, so that we can have a different + // physical plan. 
+ it.remove() + val newCache = InMemoryRelation( + useCompression = cd.cachedRepresentation.useCompression, + batchSize = cd.cachedRepresentation.batchSize, + storageLevel = cd.cachedRepresentation.storageLevel, + child = spark.sessionState.executePlan(cd.plan).executedPlan, + tableName = cd.cachedRepresentation.tableName) + needToRecache += cd.copy(cachedRepresentation = newCache) + } + } + + needToRecache.foreach(cachedData.add) } /** Optionally returns cached data for the given [[Dataset]] */ def lookupCachedData(query: Dataset[_]): Option[CachedData] = readLock { - lookupCachedData(query.queryExecution.analyzed) + lookupCachedData(query.logicalPlan) } /** Optionally returns cached data for the given [[LogicalPlan]]. */ def lookupCachedData(plan: LogicalPlan): Option[CachedData] = readLock { - cachedData.find(cd => plan.sameResult(cd.plan)) + cachedData.asScala.find(cd => plan.sameResult(cd.plan)) } /** Replaces segments of the given logical plan with cached versions where possible. */ @@ -145,40 +182,17 @@ class CacheManager extends Logging { } /** - * Invalidates the cache of any data that contains `plan`. Note that it is possible that this - * function will over invalidate. - */ - def invalidateCache(plan: LogicalPlan): Unit = writeLock { - cachedData.foreach { - case data if data.plan.collect { case p if p.sameResult(plan) => p }.nonEmpty => - data.cachedRepresentation.recache() - case _ => - } - } - - /** - * Invalidates the cache of any data that contains `resourcePath` in one or more + * Tries to re-cache all the cache entries that contain `resourcePath` in one or more * `HadoopFsRelation` node(s) as part of its logical plan. */ - def invalidateCachedPath( - sparkSession: SparkSession, resourcePath: String): Unit = writeLock { + def recacheByPath(spark: SparkSession, resourcePath: String): Unit = writeLock { val (fs, qualifiedPath) = { val path = new Path(resourcePath) - val fs = path.getFileSystem(sparkSession.sessionState.newHadoopConf()) - (fs, path.makeQualified(fs.getUri, fs.getWorkingDirectory)) + val fs = path.getFileSystem(spark.sessionState.newHadoopConf()) + (fs, fs.makeQualified(path)) } - cachedData.filter { - case data if data.plan.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined => true - case _ => false - }.foreach { data => - val dataIndex = cachedData.indexWhere(cd => data.plan.sameResult(cd.plan)) - if (dataIndex >= 0) { - data.cachedRepresentation.cachedColumnBuffers.unpersist(blocking = true) - cachedData.remove(dataIndex) - } - sparkSession.sharedState.cacheManager.cacheQuery(Dataset.ofRows(sparkSession, data.plan)) - } + recacheByCondition(spark, _.find(lookupAndRefresh(_, fs, qualifiedPath)).isDefined) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala index c106163741278..00d1d6d2701f2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlParser.scala @@ -1331,6 +1331,15 @@ class SparkSqlAstBuilder(conf: SQLConf) extends AstBuilder { if (ctx.identifierList != null) { operationNotAllowed("CREATE VIEW ... PARTITIONED ON", ctx) } else { + // CREATE VIEW ... AS INSERT INTO is not allowed. + ctx.query.queryNoWith match { + case s: SingleInsertQueryContext if s.insertInto != null => + operationNotAllowed("CREATE VIEW ... AS INSERT INTO", ctx) + case _: MultiInsertQueryContext => + operationNotAllowed("CREATE VIEW ... AS FROM ... 
[INSERT INTO ...]+", ctx) + case _ => // OK + } + val userSpecifiedColumns = Option(ctx.identifierCommentList).toSeq.flatMap { icl => icl.identifierComment.asScala.map { ic => ic.identifier.getText -> Option(ic.STRING).map(string) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 37bd95e737786..36037ac003728 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -85,12 +85,6 @@ case class InMemoryRelation( buildBuffers() } - def recache(): Unit = { - _cachedColumnBuffers.unpersist() - _cachedColumnBuffers = null - buildBuffers() - } - private def buildBuffers(): Unit = { val output = child.output val cached = child.execute().mapPartitionsInternal { rowIterator => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala index 7afa4e78a3786..5f12830ee621f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/SetCommand.scala @@ -60,6 +60,23 @@ case class SetCommand(kv: Option[(String, Option[String])]) extends RunnableComm } (keyValueOutput, runFunc) + case Some((SQLConf.Replaced.MAPREDUCE_JOB_REDUCES, Some(value))) => + val runFunc = (sparkSession: SparkSession) => { + logWarning( + s"Property ${SQLConf.Replaced.MAPREDUCE_JOB_REDUCES} is Hadoop's property, " + + s"automatically converted to ${SQLConf.SHUFFLE_PARTITIONS.key} instead.") + if (value.toInt < 1) { + val msg = + s"Setting negative ${SQLConf.Replaced.MAPREDUCE_JOB_REDUCES} for automatically " + + "determining the number of reducers is not supported." 
+ throw new IllegalArgumentException(msg) + } else { + sparkSession.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, value) + Seq(Row(SQLConf.SHUFFLE_PARTITIONS.key, value)) + } + } + (keyValueOutput, runFunc) + case Some((key @ SetCommand.VariableName(name), Some(value))) => val runFunc = (sparkSession: SparkSession) => { sparkSession.conf.set(name, value) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala index b5c60423514cb..9d3c55060dfb6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/ddl.scala @@ -199,8 +199,7 @@ case class DropTableCommand( } } try { - sparkSession.sharedState.cacheManager.uncacheQuery( - sparkSession.table(tableName.quotedString)) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) } catch { case _: NoSuchTableException if ifExists => case NonFatal(e) => log.warn(e.toString, e) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala index 921c84895598c..00f0acab21aa2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/views.scala @@ -23,9 +23,9 @@ import org.apache.spark.sql.{AnalysisException, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedRelation} import org.apache.spark.sql.catalyst.catalog.{CatalogStorageFormat, CatalogTable, CatalogTableType} -import org.apache.spark.sql.catalyst.expressions.Alias +import org.apache.spark.sql.catalyst.expressions.{Alias, SubqueryExpression} import org.apache.spark.sql.catalyst.plans.QueryPlan -import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project} +import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, View} import org.apache.spark.sql.types.MetadataBuilder @@ -154,6 +154,10 @@ case class CreateViewCommand( } else if (tableMetadata.tableType != CatalogTableType.VIEW) { throw new AnalysisException(s"$name is not a view") } else if (replace) { + // Detect cyclic view reference on CREATE OR REPLACE VIEW. + val viewIdent = tableMetadata.identifier + checkCyclicViewReference(analyzedPlan, Seq(viewIdent), viewIdent) + // Handles `CREATE OR REPLACE VIEW v0 AS SELECT ...` catalog.alterTable(prepareTable(sparkSession, analyzedPlan)) } else { @@ -283,6 +287,10 @@ case class AlterViewAsCommand( throw new AnalysisException(s"${viewMeta.identifier} is not a view.") } + // Detect cyclic view reference on ALTER VIEW. + val viewIdent = viewMeta.identifier + checkCyclicViewReference(analyzedPlan, Seq(viewIdent), viewIdent) + val newProperties = generateViewProperties(viewMeta.properties, session, analyzedPlan) val updatedViewMeta = viewMeta.copy( @@ -358,4 +366,53 @@ object ViewHelper { generateViewDefaultDatabase(viewDefaultDatabase) ++ generateQueryColumnNames(queryOutput) } + + /** + * Recursively search the logical plan to detect cyclic view references, throw an + * AnalysisException if cycle detected. 
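Stepping briefly back to the SetCommand hunk above (the view changes continue right below): a hedged sketch of how the new mapreduce.job.reduces handling is expected to behave in a spark-shell session, where `spark` is predefined:

spark.sql("SET mapreduce.job.reduces=10") // logs a warning and redirects to spark.sql.shuffle.partitions
assert(spark.conf.get("spark.sql.shuffle.partitions") == "10")
// Values below 1 are rejected with an IllegalArgumentException:
// spark.sql("SET mapreduce.job.reduces=-1")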
+ * + * A cyclic view reference is a cycle of reference dependencies, for example, if the following + * statements are executed: + * CREATE VIEW testView AS SELECT id FROM tbl + * CREATE VIEW testView2 AS SELECT id FROM testView + * ALTER VIEW testView AS SELECT * FROM testView2 + * The view `testView` references `testView2`, and `testView2` also references `testView`, + * therefore a reference cycle (testView -> testView2 -> testView) exists. + * + * @param plan the logical plan we detect cyclic view references from. + * @param path the path between the altered view and current node. + * @param viewIdent the table identifier of the altered view, we compare two views by the + * `desc.identifier`. + */ + def checkCyclicViewReference( + plan: LogicalPlan, + path: Seq[TableIdentifier], + viewIdent: TableIdentifier): Unit = { + plan match { + case v: View => + val ident = v.desc.identifier + val newPath = path :+ ident + // If the table identifier equals to the `viewIdent`, current view node is the same with + // the altered view. We detect a view reference cycle, should throw an AnalysisException. + if (ident == viewIdent) { + throw new AnalysisException(s"Recursive view $viewIdent detected " + + s"(cycle: ${newPath.mkString(" -> ")})") + } else { + v.children.foreach { child => + checkCyclicViewReference(child, newPath, viewIdent) + } + } + case _ => + plan.children.foreach(child => checkCyclicViewReference(child, path, viewIdent)) + } + + // Detect cyclic references from subqueries. + plan.expressions.foreach { expr => + expr match { + case s: SubqueryExpression => + checkCyclicViewReference(s.plan, path, viewIdent) + case _ => // Do nothing. + } + } + } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 950e5ca0d6210..30a09a9ad3370 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -341,7 +341,7 @@ object FileFormatWriter extends Logging { Seq(Cast(c, StringType, Option(desc.timeZoneId))), Seq(StringType)) val str = If(IsNull(c), Literal(ExternalCatalogUtils.DEFAULT_PARTITION_NAME), escaped) - val partitionName = Literal(c.name + "=") :: str :: Nil + val partitionName = Literal(ExternalCatalogUtils.escapePathName(c.name) + "=") :: str :: Nil if (i == 0) partitionName else Literal(Path.SEPARATOR) :: partitionName } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala index b2ff68a833fea..a813829d50cb1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/InsertIntoDataSourceCommand.scala @@ -42,8 +42,9 @@ case class InsertIntoDataSourceCommand( val df = sparkSession.internalCreateDataFrame(data.queryExecution.toRdd, logicalRelation.schema) relation.insert(df, overwrite) - // Invalidate the cache. - sparkSession.sharedState.cacheManager.invalidateCache(logicalRelation) + // Re-cache all cached plans(including this relation itself, if it's cached) that refer to this + // data source relation. 
+ sparkSession.sharedState.cacheManager.recacheByPlan(sparkSession, logicalRelation) Seq.empty[Row] } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala index 73e6abc6dad37..47567032b0195 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/csv/CSVDataSource.scala @@ -133,20 +133,24 @@ object TextInputCSVDataSource extends CSVDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): Option[StructType] = { - val csv: Dataset[String] = createBaseDataset(sparkSession, inputPaths, parsedOptions) - val firstLine: String = CSVUtils.filterCommentAndEmpty(csv, parsedOptions).first() - val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = csv.rdd.mapPartitions { iter => - val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) - val linesWithoutHeader = - CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) - val parser = new CsvParser(parsedOptions.asParserSettings) - linesWithoutHeader.map(parser.parseLine) + val csv = createBaseDataset(sparkSession, inputPaths, parsedOptions) + CSVUtils.filterCommentAndEmpty(csv, parsedOptions).take(1).headOption match { + case Some(firstLine) => + val firstRow = new CsvParser(parsedOptions.asParserSettings).parseLine(firstLine) + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.rdd.mapPartitions { iter => + val filteredLines = CSVUtils.filterCommentAndEmpty(iter, parsedOptions) + val linesWithoutHeader = + CSVUtils.filterHeaderLine(filteredLines, firstLine, parsedOptions) + val parser = new CsvParser(parsedOptions.asParserSettings) + linesWithoutHeader.map(parser.parseLine) + } + Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + case None => + // If the first line could not be read, just return the empty schema. 
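As a user-level illustration of the empty-input handling added to CSV schema inference above — a hedged sketch for a spark-shell session, assuming a directory that exists but contains no readable data lines (the path is made up):

// With the change above, inference over input with no readable first line is expected to
// yield an empty schema (StructType(Nil)) instead of failing on first().
val empty = spark.read.option("header", "true").csv("/tmp/no-data-here") // hypothetical path
assert(empty.schema == org.apache.spark.sql.types.StructType(Nil))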
+ Some(StructType(Nil)) } - - Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) } private def createBaseDataset( @@ -190,28 +194,28 @@ object WholeFileCSVDataSource extends CSVDataSource { sparkSession: SparkSession, inputPaths: Seq[FileStatus], parsedOptions: CSVOptions): Option[StructType] = { - val csv: RDD[PortableDataStream] = createBaseRdd(sparkSession, inputPaths, parsedOptions) - val maybeFirstRow: Option[Array[String]] = csv.flatMap { lines => + val csv = createBaseRdd(sparkSession, inputPaths, parsedOptions) + csv.flatMap { lines => UnivocityParser.tokenizeStream( CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), - false, + shouldDropHeader = false, new CsvParser(parsedOptions.asParserSettings)) - }.take(1).headOption - - if (maybeFirstRow.isDefined) { - val firstRow = maybeFirstRow.get - val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis - val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) - val tokenRDD = csv.flatMap { lines => - UnivocityParser.tokenizeStream( - CodecStreams.createInputStreamWithCloseResource(lines.getConfiguration, lines.getPath()), - parsedOptions.headerFlag, - new CsvParser(parsedOptions.asParserSettings)) - } - Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) - } else { - // If the first row could not be read, just return the empty schema. - Some(StructType(Nil)) + }.take(1).headOption match { + case Some(firstRow) => + val caseSensitive = sparkSession.sessionState.conf.caseSensitiveAnalysis + val header = makeSafeHeader(firstRow, caseSensitive, parsedOptions) + val tokenRDD = csv.flatMap { lines => + UnivocityParser.tokenizeStream( + CodecStreams.createInputStreamWithCloseResource( + lines.getConfiguration, + lines.getPath()), + parsedOptions.headerFlag, + new CsvParser(parsedOptions.asParserSettings)) + } + Some(CSVInferSchema.infer(tokenRDD, header, parsedOptions)) + case None => + // If the first row could not be read, just return the empty schema. + Some(StructType(Nil)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala index d92529748b6ac..cbf656a2044dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/statefulOperators.scala @@ -68,7 +68,7 @@ trait StateStoreWriter extends StatefulOperator { } /** An operator that supports watermark. */ -trait WatermarkSupport extends SparkPlan { +trait WatermarkSupport extends UnaryExecNode { /** The keys that may have a watermark attribute. */ def keyExpressions: Seq[Attribute] @@ -76,8 +76,8 @@ trait WatermarkSupport extends SparkPlan { /** The watermark value. 
*/ def eventTimeWatermark: Option[Long] - /** Generate a predicate that matches data older than the watermark */ - lazy val watermarkPredicate: Option[Predicate] = { + /** Generate an expression that matches data older than the watermark */ + lazy val watermarkExpression: Option[Expression] = { val optionalWatermarkAttribute = keyExpressions.find(_.metadata.contains(EventTimeWatermark.delayKey)) @@ -96,9 +96,19 @@ trait WatermarkSupport extends SparkPlan { } logInfo(s"Filtering state store on: $evictionExpression") - newPredicate(evictionExpression, keyExpressions) + evictionExpression } } + + /** Generate a predicate based on keys that matches data older than the watermark */ + lazy val watermarkPredicateForKeys: Option[Predicate] = + watermarkExpression.map(newPredicate(_, keyExpressions)) + + /** + * Generate a predicate based on the child output that matches data older than the watermark. + */ + lazy val watermarkPredicate: Option[Predicate] = + watermarkExpression.map(newPredicate(_, child.output)) } /** @@ -192,7 +202,7 @@ case class StateStoreSaveExec( } // Assumption: Append mode can be done only when watermark has been specified - store.remove(watermarkPredicate.get.eval _) + store.remove(watermarkPredicateForKeys.get.eval _) store.commit() numTotalStateRows += store.numKeys() @@ -215,7 +225,9 @@ case class StateStoreSaveExec( override def hasNext: Boolean = { if (!baseIterator.hasNext) { // Remove old aggregates if watermark specified - if (watermarkPredicate.nonEmpty) store.remove(watermarkPredicate.get.eval _) + if (watermarkPredicateForKeys.nonEmpty) { + store.remove(watermarkPredicateForKeys.get.eval _) + } store.commit() numTotalStateRows += store.numKeys() false @@ -361,7 +373,7 @@ case class StreamingDeduplicateExec( val numUpdatedStateRows = longMetric("numUpdatedStateRows") val baseIterator = watermarkPredicate match { - case Some(predicate) => iter.filter((row: InternalRow) => !predicate.eval(row)) + case Some(predicate) => iter.filter(row => !predicate.eval(row)) case None => iter } @@ -381,7 +393,7 @@ case class StreamingDeduplicateExec( } CompletionIterator[InternalRow, Iterator[InternalRow]](result, { - watermarkPredicate.foreach(f => store.remove(f.eval _)) + watermarkPredicateForKeys.foreach(f => store.remove(f.eval _)) store.commit() numTotalStateRows += store.numKeys() }) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala index ed07ff3ff0599..53374859f13f4 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/CatalogImpl.scala @@ -343,8 +343,8 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def dropTempView(viewName: String): Boolean = { - sparkSession.sessionState.catalog.getTempView(viewName).exists { tempView => - sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, tempView)) + sparkSession.sessionState.catalog.getTempView(viewName).exists { viewDef => + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) sessionCatalog.dropTempView(viewName) } } @@ -359,7 +359,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { */ override def dropGlobalTempView(viewName: String): Boolean = { sparkSession.sessionState.catalog.getGlobalTempView(viewName).exists { viewDef => - 
sparkSession.sharedState.cacheManager.uncacheQuery(Dataset.ofRows(sparkSession, viewDef)) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession, viewDef, blocking = true) sessionCatalog.dropGlobalTempView(viewName) } } @@ -404,7 +404,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def uncacheTable(tableName: String): Unit = { - sparkSession.sharedState.cacheManager.uncacheQuery(query = sparkSession.table(tableName)) + sparkSession.sharedState.cacheManager.uncacheQuery(sparkSession.table(tableName)) } /** @@ -442,17 +442,12 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { // If this table is cached as an InMemoryRelation, drop the original // cached version and make the new version cached lazily. - val logicalPlan = sparkSession.table(tableIdent).queryExecution.analyzed - // Use lookupCachedData directly since RefreshTable also takes databaseName. - val isCached = sparkSession.sharedState.cacheManager.lookupCachedData(logicalPlan).nonEmpty - if (isCached) { - // Create a data frame to represent the table. - // TODO: Use uncacheTable once it supports database name. - val df = Dataset.ofRows(sparkSession, logicalPlan) + val table = sparkSession.table(tableIdent) + if (isCached(table)) { // Uncache the logicalPlan. - sparkSession.sharedState.cacheManager.uncacheQuery(df, blocking = true) + sparkSession.sharedState.cacheManager.uncacheQuery(table, blocking = true) // Cache it again. - sparkSession.sharedState.cacheManager.cacheQuery(df, Some(tableIdent.table)) + sparkSession.sharedState.cacheManager.cacheQuery(table, Some(tableIdent.table)) } } @@ -464,7 +459,7 @@ class CatalogImpl(sparkSession: SparkSession) extends Catalog { * @since 2.0.0 */ override def refreshByPath(resourcePath: String): Unit = { - sparkSession.sharedState.cacheManager.invalidateCachedPath(sparkSession, resourcePath) + sparkSession.sharedState.cacheManager.recacheByPath(sparkSession, resourcePath) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 2d81851565aa1..1244f690fd829 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -668,6 +668,18 @@ object SQLConf { .booleanConf .createWithDefault(false) + val JOIN_REORDER_ENABLED = + buildConf("spark.sql.cbo.joinReorder.enabled") + .doc("Enables join reorder in CBO.") + .booleanConf + .createWithDefault(false) + + val JOIN_REORDER_DP_THRESHOLD = + buildConf("spark.sql.cbo.joinReorder.dp.threshold") + .doc("The maximum number of joined nodes allowed in the dynamic programming algorithm.") + .intConf + .createWithDefault(12) + val SESSION_LOCAL_TIMEZONE = buildConf("spark.sql.session.timeZone") .doc("""The ID of session local timezone, e.g. 
"GMT", "America/Los_Angeles", etc.""") @@ -677,6 +689,10 @@ object SQLConf { object Deprecated { val MAPRED_REDUCE_TASKS = "mapred.reduce.tasks" } + + object Replaced { + val MAPREDUCE_JOB_REDUCES = "mapreduce.job.reduces" + } } /** @@ -881,6 +897,10 @@ private[sql] class SQLConf extends Serializable with CatalystConf with Logging { override def cboEnabled: Boolean = getConf(SQLConf.CBO_ENABLED) + override def joinReorderEnabled: Boolean = getConf(SQLConf.JOIN_REORDER_ENABLED) + + override def joinReorderDPThreshold: Int = getConf(SQLConf.JOIN_REORDER_DP_THRESHOLD) + /** ********************** SQLConf functionality methods ************ */ /** Set Spark SQL configuration properties. */ diff --git a/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql new file mode 100644 index 0000000000000..9308560451bf5 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/inputs/json-functions.sql @@ -0,0 +1,8 @@ +-- to_json +describe function to_json; +describe function extended to_json; +select to_json(named_struct('a', 1, 'b', 2)); +select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); +-- Check if errors handled +select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')); +select to_json(); diff --git a/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out new file mode 100644 index 0000000000000..d8aa4fb9fa788 --- /dev/null +++ b/sql/core/src/test/resources/sql-tests/results/json-functions.sql.out @@ -0,0 +1,63 @@ +-- Automatically generated by SQLQueryTestSuite +-- Number of queries: 6 + + +-- !query 0 +describe function to_json +-- !query 0 schema +struct +-- !query 0 output +Class: org.apache.spark.sql.catalyst.expressions.StructToJson +Function: to_json +Usage: to_json(expr[, options]) - Returns a json string with a given struct value + + +-- !query 1 +describe function extended to_json +-- !query 1 schema +struct +-- !query 1 output +Class: org.apache.spark.sql.catalyst.expressions.StructToJson +Extended Usage: + Examples: + > SELECT to_json(named_struct('a', 1, 'b', 2)); + {"a":1,"b":2} + > SELECT to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')); + {"time":"26/08/2015"} + +Function: to_json +Usage: to_json(expr[, options]) - Returns a json string with a given struct value + + +-- !query 2 +select to_json(named_struct('a', 1, 'b', 2)) +-- !query 2 schema +struct +-- !query 2 output +{"a":1,"b":2} + + +-- !query 3 +select to_json(named_struct('time', to_timestamp('2015-08-26', 'yyyy-MM-dd')), map('timestampFormat', 'dd/MM/yyyy')) +-- !query 3 schema +struct +-- !query 3 output +{"time":"26/08/2015"} + + +-- !query 4 +select to_json(named_struct('a', 1, 'b', 2), named_struct('mode', 'PERMISSIVE')) +-- !query 4 schema +struct<> +-- !query 4 output +org.apache.spark.sql.AnalysisException +Must use a map() function for options;; line 1 pos 7 + + +-- !query 5 +select to_json() +-- !query 5 schema +struct<> +-- !query 5 output +org.apache.spark.sql.AnalysisException +Invalid number of arguments for function to_json; line 1 pos 7 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 2a0e088437fda..7a7d52b21427a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala 
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -24,15 +24,15 @@ import scala.language.postfixOps import org.scalatest.concurrent.Eventually._ import org.apache.spark.CleanerListener -import org.apache.spark.sql.catalyst.expressions.{Expression, SubqueryExpression} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.TableIdentifier +import org.apache.spark.sql.catalyst.expressions.SubqueryExpression import org.apache.spark.sql.execution.RDDScanExec import org.apache.spark.sql.execution.columnar._ import org.apache.spark.sql.execution.exchange.ShuffleExchange import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils} import org.apache.spark.storage.{RDDBlockId, StorageLevel} -import org.apache.spark.util.AccumulatorContext +import org.apache.spark.util.{AccumulatorContext, Utils} private case class BigData(s: String) @@ -65,7 +65,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext maybeBlock.nonEmpty } - private def getNumInMemoryRelations(plan: LogicalPlan): Int = { + private def getNumInMemoryRelations(ds: Dataset[_]): Int = { + val plan = ds.queryExecution.withCachedData var sum = plan.collect { case _: InMemoryRelation => 1 }.sum plan.transformAllExpressions { case e: SubqueryExpression => @@ -187,7 +188,7 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext assertCached(spark.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - getNumInMemoryRelations(spark.table("testData").queryExecution.withCachedData) + getNumInMemoryRelations(spark.table("testData")) } spark.catalog.cacheTable("testData") @@ -580,21 +581,21 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext localRelation.createOrReplaceTempView("localRelation") spark.catalog.cacheTable("localRelation") - assert(getNumInMemoryRelations(localRelation.queryExecution.withCachedData) == 1) + assert(getNumInMemoryRelations(localRelation) == 1) } test("SPARK-19093 Caching in side subquery") { withTempView("t1") { Seq(1).toDF("c1").createOrReplaceTempView("t1") spark.catalog.cacheTable("t1") - val cachedPlan = + val ds = sql( """ |SELECT * FROM t1 |WHERE |NOT EXISTS (SELECT * FROM t1) - """.stripMargin).queryExecution.optimizedPlan - assert(getNumInMemoryRelations(cachedPlan) == 2) + """.stripMargin) + assert(getNumInMemoryRelations(ds) == 2) } } @@ -610,17 +611,17 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext spark.catalog.cacheTable("t4") // Nested predicate subquery - val cachedPlan = + val ds = sql( """ |SELECT * FROM t1 |WHERE |c1 IN (SELECT c1 FROM t2 WHERE c1 IN (SELECT c1 FROM t3 WHERE c1 = 1)) - """.stripMargin).queryExecution.optimizedPlan - assert(getNumInMemoryRelations(cachedPlan) == 3) + """.stripMargin) + assert(getNumInMemoryRelations(ds) == 3) // Scalar subquery and predicate subquery - val cachedPlan2 = + val ds2 = sql( """ |SELECT * FROM (SELECT max(c1) FROM t1 GROUP BY c1) @@ -630,8 +631,27 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSQLContext |EXISTS (SELECT c1 FROM t3) |OR |c1 IN (SELECT c1 FROM t4) - """.stripMargin).queryExecution.optimizedPlan - assert(getNumInMemoryRelations(cachedPlan2) == 4) + """.stripMargin) + assert(getNumInMemoryRelations(ds2) == 4) + } + } + + test("SPARK-19765: UNCACHE TABLE should un-cache all cached plans that refer to this table") { + withTable("t") 
{ + withTempPath { path => + Seq(1 -> "a").toDF("i", "j").write.parquet(path.getCanonicalPath) + sql(s"CREATE TABLE t USING parquet LOCATION '$path'") + spark.catalog.cacheTable("t") + spark.table("t").select($"i").cache() + checkAnswer(spark.table("t").select($"i"), Row(1)) + assertCached(spark.table("t").select($"i")) + + Utils.deleteRecursively(path) + spark.sessionState.catalog.refreshTable(TableIdentifier("t")) + spark.catalog.uncacheTable("t") + assert(spark.table("t").select($"i").count() == 0) + assert(getNumInMemoryRelations(spark.table("t").select($"i")) == 0) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala index 953d161ec2a1d..cdea3b9a0f79f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala @@ -197,4 +197,27 @@ class JsonFunctionsSuite extends QueryTest with SharedSQLContext { .select(to_json($"struct").as("json")) checkAnswer(dfTwo, readBackTwo) } + + test("SPARK-19637 Support to_json in SQL") { + val df1 = Seq(Tuple1(Tuple1(1))).toDF("a") + checkAnswer( + df1.selectExpr("to_json(a)"), + Row("""{"_1":1}""") :: Nil) + + val df2 = Seq(Tuple1(Tuple1(java.sql.Timestamp.valueOf("2015-08-26 18:00:00.0")))).toDF("a") + checkAnswer( + df2.selectExpr("to_json(a, map('timestampFormat', 'dd/MM/yyyy HH:mm'))"), + Row("""{"_1":"26/08/2015 18:00"}""") :: Nil) + + val errMsg1 = intercept[AnalysisException] { + df2.selectExpr("to_json(a, named_struct('a', 1))") + } + assert(errMsg1.getMessage.startsWith("Must use a map() function for options")) + + val errMsg2 = intercept[AnalysisException] { + df2.selectExpr("to_json(a, map('a', 1))") + } + assert(errMsg2.getMessage.startsWith( + "A type of keys and values in map() must be string, but got")) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 468ea0551298e..d9e0196c57957 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -1019,6 +1019,18 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext { spark.sessionState.conf.clear() } + test("SET mapreduce.job.reduces automatically converted to spark.sql.shuffle.partitions") { + spark.sessionState.conf.clear() + val before = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS.key).toInt + val newConf = before + 1 + sql(s"SET mapreduce.job.reduces=${newConf.toString}") + val after = spark.conf.get(SQLConf.SHUFFLE_PARTITIONS.key).toInt + assert(before != after) + assert(newConf === after) + intercept[IllegalArgumentException](sql(s"SET mapreduce.job.reduces=-1")) + spark.sessionState.conf.clear() + } + test("apply schema") { val schema1 = StructType( StructField("f1", IntegerType, false) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala index bbb31dbc8f3de..1f547c5a2a8ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/StatisticsCollectionSuite.scala @@ -112,30 +112,6 @@ class StatisticsCollectionSuite extends StatisticsCollectionTestBase with Shared spark.sessionState.conf.autoBroadcastJoinThreshold) } - test("estimates the size of limit") { - withTempView("test") { - 
Seq(("one", 1), ("two", 2), ("three", 3), ("four", 4)).toDF("k", "v") - .createOrReplaceTempView("test") - Seq((0, 1), (1, 24), (2, 48)).foreach { case (limit, expected) => - val df = sql(s"""SELECT * FROM test limit $limit""") - - val sizesGlobalLimit = df.queryExecution.analyzed.collect { case g: GlobalLimit => - g.stats(conf).sizeInBytes - } - assert(sizesGlobalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") - assert(sizesGlobalLimit.head === BigInt(expected), - s"expected exact size $expected for table 'test', got: ${sizesGlobalLimit.head}") - - val sizesLocalLimit = df.queryExecution.analyzed.collect { case l: LocalLimit => - l.stats(conf).sizeInBytes - } - assert(sizesLocalLimit.size === 1, s"Size wrong for:\n ${df.queryExecution}") - assert(sizesLocalLimit.head === BigInt(expected), - s"expected exact size $expected for table 'test', got: ${sizesLocalLimit.head}") - } - } - } - test("column stats round trip serialization") { // Make sure we serialize and then deserialize and we will get the result data val df = data.toDF(stats.keys.toSeq :+ "carray" : _*) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 0bfc92fdb6218..02ccebd22bdf9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -242,11 +242,12 @@ class PlannerSuite extends SharedSQLContext { val doubleRepartitioned = testData.repartition(10).repartition(20).coalesce(5) def countRepartitions(plan: LogicalPlan): Int = plan.collect { case r: Repartition => r }.length assert(countRepartitions(doubleRepartitioned.queryExecution.logical) === 3) - assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 1) + assert(countRepartitions(doubleRepartitioned.queryExecution.optimizedPlan) === 2) doubleRepartitioned.queryExecution.optimizedPlan match { - case r: Repartition => - assert(r.numPartitions === 5) - assert(r.shuffle === false) + case Repartition (numPartitions, shuffle, Repartition(_, shuffleChild, _)) => + assert(numPartitions === 5) + assert(shuffle === false) + assert(shuffleChild === true) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala index 2d95cb6d64a87..2ca2206bb9d44 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SQLViewSuite.scala @@ -172,7 +172,7 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { var e = intercept[AnalysisException] { sql(s"INSERT INTO TABLE $viewName SELECT 1") }.getMessage - assert(e.contains("Inserting into an RDD-based table is not allowed")) + assert(e.contains("Inserting into a view is not allowed. View: `default`.`testview`")) val dataFilePath = Thread.currentThread().getContextClassLoader.getResource("data/files/employee.dat") @@ -609,12 +609,39 @@ abstract class SQLViewSuite extends QueryTest with SQLTestUtils { } } - // TODO: Check for cyclic view references on ALTER VIEW. 
- ignore("correctly handle a cyclic view reference") { - withView("view1", "view2") { + test("correctly handle a cyclic view reference") { + withView("view1", "view2", "view3") { sql("CREATE VIEW view1 AS SELECT * FROM jt") sql("CREATE VIEW view2 AS SELECT * FROM view1") - intercept[AnalysisException](sql("ALTER VIEW view1 AS SELECT * FROM view2")) + sql("CREATE VIEW view3 AS SELECT * FROM view2") + + // Detect cyclic view reference on ALTER VIEW. + val e1 = intercept[AnalysisException] { + sql("ALTER VIEW view1 AS SELECT * FROM view2") + }.getMessage + assert(e1.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " + + "-> `default`.`view2` -> `default`.`view1`)")) + + // Detect the most left cycle when there exists multiple cyclic view references. + val e2 = intercept[AnalysisException] { + sql("ALTER VIEW view1 AS SELECT * FROM view3 JOIN view2") + }.getMessage + assert(e2.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " + + "-> `default`.`view3` -> `default`.`view2` -> `default`.`view1`)")) + + // Detect cyclic view reference on CREATE OR REPLACE VIEW. + val e3 = intercept[AnalysisException] { + sql("CREATE OR REPLACE VIEW view1 AS SELECT * FROM view2") + }.getMessage + assert(e3.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " + + "-> `default`.`view2` -> `default`.`view1`)")) + + // Detect cyclic view reference from subqueries. + val e4 = intercept[AnalysisException] { + sql("ALTER VIEW view1 AS SELECT * FROM jt WHERE EXISTS (SELECT 1 FROM view2)") + }.getMessage + assert(e4.contains("Recursive view `default`.`view1` detected (cycle: `default`.`view1` " + + "-> `default`.`view2` -> `default`.`view1`)")) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala index bb6c486e880a0..a4d012cd76115 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkSqlParserSuite.scala @@ -45,7 +45,7 @@ class SparkSqlParserSuite extends PlanTest { * Normalizes plans: * - CreateTable the createTime in tableDesc will replaced by -1L. */ - private def normalizePlan(plan: LogicalPlan): LogicalPlan = { + override def normalizePlan(plan: LogicalPlan): LogicalPlan = { plan match { case CreateTable(tableDesc, mode, query) => val newTableDesc = tableDesc.copy(createTime = -1L) @@ -210,6 +210,17 @@ class SparkSqlParserSuite extends PlanTest { "no viable alternative at input") } + test("create view as insert into table") { + // Single insert query + intercept("CREATE VIEW testView AS INSERT INTO jt VALUES(1, 1)", + "Operation not allowed: CREATE VIEW ... AS INSERT INTO") + + // Multi insert query + intercept("CREATE VIEW testView AS FROM jt INSERT INTO tbl1 SELECT * WHERE jt.id < 5 " + + "INSERT INTO tbl2 SELECT * WHERE jt.id > 4", + "Operation not allowed: CREATE VIEW ... AS FROM ... 
[INSERT INTO ...]+") + } + test("SPARK-17328 Fix NPE with EXPLAIN DESCRIBE TABLE") { assertEqual("describe table t", DescribeTableCommand( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala index 6ffa58bcd9af1..b2199fdf90e5c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DDLSuite.scala @@ -1995,6 +1995,29 @@ class DDLSuite extends QueryTest with SharedSQLContext with BeforeAndAfterEach { } } + Seq("a b", "a:b", "a%b", "a,b").foreach { specialChars => + test(s"data source table:partition column name containing $specialChars") { + withTable("t") { + withTempDir { dir => + spark.sql( + s""" + |CREATE TABLE t(a string, `$specialChars` string) + |USING parquet + |PARTITIONED BY(`$specialChars`) + |LOCATION '$dir' + """.stripMargin) + + assert(dir.listFiles().isEmpty) + spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1") + val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2" + val partFile = new File(dir, partEscaped) + assert(partFile.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1", "2") :: Nil) + } + } + } + } + Seq("a b", "a:b", "a%b").foreach { specialChars => test(s"location uri contains $specialChars for datasource table") { withTable("t", "t1") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala index 56071803f685f..eaedede349134 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/csv/CSVSuite.scala @@ -1077,14 +1077,12 @@ class CSVSuite extends QueryTest with SharedSQLContext with SQLTestUtils { } } - test("Empty file produces empty dataframe with empty schema - wholeFile option") { - withTempPath { path => - path.createNewFile() - + test("Empty file produces empty dataframe with empty schema") { + Seq(false, true).foreach { wholeFile => val df = spark.read.format("csv") .option("header", true) - .option("wholeFile", true) - .load(path.getAbsolutePath) + .option("wholeFile", wholeFile) + .load(testFile(emptyFile)) assert(df.schema === spark.emptyDataFrame.schema) checkAnswer(df, spark.emptyDataFrame) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 9c55357ab9bc1..26c45e092dc65 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -22,15 +22,12 @@ import scala.reflect.ClassTag import org.apache.spark.AccumulatorSuite import org.apache.spark.sql.{Dataset, QueryTest, Row, SparkSession} import org.apache.spark.sql.catalyst.expressions.{BitwiseAnd, BitwiseOr, Cast, Literal, ShiftLeft} -import org.apache.spark.sql.catalyst.plans.logical.SubqueryAlias -import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.exchange.EnsureRequirements import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.functions._ import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SQLTestUtils import 
org.apache.spark.sql.types.{LongType, ShortType} -import org.apache.spark.util.Utils /** * Test various broadcast join operators. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala index 7ea716231e5dc..a15c2cff930fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/DeduplicateSuite.scala @@ -249,4 +249,23 @@ class DeduplicateSuite extends StateStoreMetricsTest with BeforeAndAfterAll { } } } + + test("SPARK-19841: watermarkPredicate should filter based on keys") { + val input = MemoryStream[(Int, Int)] + val df = input.toDS.toDF("time", "id") + .withColumn("time", $"time".cast("timestamp")) + .withWatermark("time", "1 second") + .dropDuplicates("id", "time") // Change the column positions + .select($"id") + testStream(df)( + AddData(input, 1 -> 1, 1 -> 1, 1 -> 2), + CheckLastBatch(1, 2), + AddData(input, 1 -> 1, 2 -> 3, 2 -> 4), + CheckLastBatch(3, 4), + AddData(input, 1 -> 0, 1 -> 1, 3 -> 5, 3 -> 6), // Drop (1 -> 0, 1 -> 1) due to watermark + CheckLastBatch(5, 6), + AddData(input, 1 -> 0, 4 -> 7), // Drop (1 -> 0) due to watermark + CheckLastBatch(7) + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala index c34d119734cc0..c768525bc6855 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/EventTimeWatermarkSuite.scala @@ -25,6 +25,7 @@ import org.scalatest.BeforeAndAfter import org.apache.spark.internal.Logging import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark import org.apache.spark.sql.execution.streaming._ import org.apache.spark.sql.functions.{count, window} import org.apache.spark.sql.streaming.OutputMode._ @@ -305,6 +306,19 @@ class EventTimeWatermarkSuite extends StreamTest with BeforeAndAfter with Loggin ) } + test("the new watermark should override the old one") { + val df = MemoryStream[(Long, Long)].toDF() + .withColumn("first", $"_1".cast("timestamp")) + .withColumn("second", $"_2".cast("timestamp")) + .withWatermark("first", "1 minute") + .withWatermark("second", "2 minutes") + + val eventTimeColumns = df.logicalPlan.output + .filter(_.metadata.contains(EventTimeWatermark.delayKey)) + assert(eventTimeColumns.size === 1) + assert(eventTimeColumns(0).name === "second") + } + private def assertNumStateRows(numTotalRows: Long): AssertOnQuery = AssertOnQuery { q => val progressWithData = q.recentProgress.filter(_.numInputRows > 0).lastOption.get assert(progressWithData.stateOperators(0).numRowsTotal === numTotalRows) diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala index 3c57ee4c8b8f6..b8536d0c1bd58 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/InsertIntoHiveTable.scala @@ -393,8 +393,8 @@ case class InsertIntoHiveTable( logWarning(s"Unable to delete staging directory: $stagingDir.\n" + e) } - // Invalidate the cache. - sparkSession.catalog.uncacheTable(table.qualifiedName) + // un-cache this table. 
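// A sketch (not part of this patch) adapted from the new EventTimeWatermarkSuite case
// above: when withWatermark is called twice, only the most recent call survives, so only
// the "second" column should carry the event-time delay metadata. Assumes a SparkSession
// named `spark` with its implicits imported; MemoryStream is the same test-only source
// the suite itself uses.
import org.apache.spark.sql.catalyst.plans.logical.EventTimeWatermark
import org.apache.spark.sql.execution.streaming.MemoryStream
import spark.implicits._

implicit val sqlContext = spark.sqlContext
val input = MemoryStream[(Long, Long)]
val df = input.toDF()
  .withColumn("first", $"_1".cast("timestamp"))
  .withColumn("second", $"_2".cast("timestamp"))
  .withWatermark("first", "1 minute")
  .withWatermark("second", "2 minutes")

val eventTimeColumns =
  df.queryExecution.analyzed.output.filter(_.metadata.contains(EventTimeWatermark.delayKey))
// Expected with this change: only "second" is reported.
println(eventTimeColumns.map(_.name))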
+ sparkSession.catalog.uncacheTable(table.identifier.quotedString) sparkSession.sessionState.catalog.refreshTable(table.identifier) // It would be nice to just return the childRdd unchanged so insert operations could be chained, diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala index 8ccc2b7527f24..2b3f36064c1f8 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/CachedTableSuite.scala @@ -195,10 +195,8 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with TestHiveSingleto tempPath.delete() table("src").write.mode(SaveMode.Overwrite).parquet(tempPath.toString) sql("DROP TABLE IF EXISTS refreshTable") - sparkSession.catalog.createExternalTable("refreshTable", tempPath.toString, "parquet") - checkAnswer( - table("refreshTable"), - table("src").collect()) + sparkSession.catalog.createTable("refreshTable", tempPath.toString, "parquet") + checkAnswer(table("refreshTable"), table("src")) // Cache the table. sql("CACHE TABLE refreshTable") assertCached(table("refreshTable")) diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala index 892a22ddfafc8..cf552b4a88b2c 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala @@ -64,7 +64,7 @@ class HiveMetastoreCatalogSuite extends TestHiveSingleton with SQLTestUtils { spark.sql("create view vw1 as select 1 as id") val plan = spark.sql("select id from vw1").queryExecution.analyzed val aliases = plan.collect { - case x @ SubqueryAlias("vw1", _, Some(TableIdentifier("vw1", Some("default")))) => x + case x @ SubqueryAlias("vw1", _) => x } assert(aliases.size == 1) } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index e956c9abae514..df2c1cee942b0 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -27,7 +27,7 @@ import org.scalatest.BeforeAndAfterEach import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SaveMode} import org.apache.spark.sql.catalyst.analysis.{NoSuchPartitionException, TableAlreadyExistsException} -import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType, CatalogUtils} +import org.apache.spark.sql.catalyst.catalog.{CatalogDatabase, CatalogTable, CatalogTableType, CatalogUtils, ExternalCatalogUtils} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.execution.command.DDLUtils import org.apache.spark.sql.hive.HiveExternalCatalog @@ -1690,6 +1690,39 @@ class HiveDDLSuite } } + Seq("parquet", "hive").foreach { datasource => + Seq("a b", "a:b", "a%b", "a,b").foreach { specialChars => + test(s"partition column name of $datasource table containing $specialChars") { + withTable("t") { + withTempDir { dir => + spark.sql( + s""" + |CREATE TABLE t(a string, `$specialChars` string) + |USING $datasource + |PARTITIONED BY(`$specialChars`) + |LOCATION '$dir' + """.stripMargin) + + assert(dir.listFiles().isEmpty) + spark.sql(s"INSERT 
INTO TABLE t PARTITION(`$specialChars`=2) SELECT 1") + val partEscaped = s"${ExternalCatalogUtils.escapePathName(specialChars)}=2" + val partFile = new File(dir, partEscaped) + assert(partFile.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1", "2") :: Nil) + + withSQLConf("hive.exec.dynamic.partition.mode" -> "nonstrict") { + spark.sql(s"INSERT INTO TABLE t PARTITION(`$specialChars`) SELECT 3, 4") + val partEscaped1 = s"${ExternalCatalogUtils.escapePathName(specialChars)}=4" + val partFile1 = new File(dir, partEscaped1) + assert(partFile1.listFiles().length >= 1) + checkAnswer(spark.table("t"), Row("1", "2") :: Row("3", "4") :: Nil) + } + } + } + } + } + } + Seq("a b", "a:b", "a%b").foreach { specialChars => test(s"datasource table: location uri contains $specialChars") { withTable("t", "t1") { diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala index be9a5fd71bd25..236135dcff523 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala @@ -1030,7 +1030,7 @@ class SQLQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton { withSQLConf(SQLConf.CONVERT_CTAS.key -> "false") { sql("CREATE TABLE explodeTest (key bigInt)") table("explodeTest").queryExecution.analyzed match { - case SubqueryAlias(_, r: CatalogRelation, _) => // OK + case SubqueryAlias(_, r: CatalogRelation) => // OK case _ => fail("To correctly test the fix of SPARK-5875, explodeTest should be a MetastoreRelation") } diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala index 3512c4a890313..81af24979d822 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala @@ -453,7 +453,7 @@ class ParquetMetastoreSuite extends ParquetPartitioningTest { // Converted test_parquet should be cached. sessionState.catalog.getCachedDataSourceTable(tableIdentifier) match { case null => fail("Converted test_parquet should be cached in the cache.") - case logical @ LogicalRelation(parquetRelation: HadoopFsRelation, _, _) => // OK + case LogicalRelation(_: HadoopFsRelation, _, _) => // OK case other => fail( "The cached test_parquet should be a Parquet Relation. " +